[Large PR] Entire rework of pipelines. (#13308)

* Enabling dataset iteration on pipelines.

Unifying parameters under `set_parameters` function.

Small fix.

Last fixes after rebase

Remove print.

Fixing text2text `generate_kwargs`

No more `self.max_length`.

Fixing TF-only conversational.

Consistency in start/stop index over TF/PT.

Speeding up drastically on TF (nasty bug where `max_length` would increase a ton).

Adding test for support of non-fast tokenizers.

Fixing GPU usage on zero-shot.

Fix working on TF.

Update src/transformers/pipelines/base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Update src/transformers/pipelines/base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Small cleanup.

Remove all asserts + simple format.

* Fixing audio-classification for large PR.

* Overly explicit null checking.

* Encapsulating GPU/CPU PyTorch manipulation directly within `base.py`.

* Removed internal state for parameters of the pipeline.

Instead of implicitly overriding internal state, we moved
to real named arguments on every `preprocess`, `_forward`,
and `postprocess` function.

Instead, `_sanitize_parameters` is used to split all kwargs
of both `__init__` and `__call__` into the 3 kinds of named parameters.
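
For illustration, a minimal sketch of the new contract (hypothetical pipeline code, not taken from this PR):

    def _sanitize_parameters(self, my_option=None, **kwargs):
        # Route each extra kwarg to the phase that consumes it.
        preprocess_kwargs = {}
        if my_option is not None:
            preprocess_kwargs["my_option"] = my_option
        # (preprocess kwargs, _forward kwargs, postprocess kwargs)
        return preprocess_kwargs, {}, {}

    # __call__ fuses these with any per-call overrides, then runs:
    # preprocess(inputs, **preprocess_kwargs) -> _forward(...) -> postprocess(...)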

* Move import warnings.

* Small fixes.

* Quality.

* Another small fix, using the CI to debug faster.

* Last fixes.

* Last fix.

* Small cleanup of tensor moving.

* is not None.

* Adding a bunch of docs + an iteration test.

* Fixing doc style.

* KeyDataset = None guard.

* Removing the CUDA test for pipelines (was testing).

* Even more simple iteration test.

* Correct import.

* Long day.

* Fixes in docs.

* [WIP] migrating object detection.

* Fixed the target_size bug.

* Fixup.

* Bad variable name.

* Fixing `ensure_on_device` respects original ModelOutput.
Nicolas Patry 2021-09-10 14:47:48 +02:00 committed by GitHub
parent 09549aa18c
commit c63fcabfe9
28 changed files with 1559 additions and 1290 deletions


@@ -0,0 +1,143 @@
..
Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
How to add a pipeline to 🤗 Transformers?
=======================================================================================================================
First and foremost, you need to decide the raw entries the pipeline will be able to take. These can be strings, raw
bytes, dictionaries or whatever seems to be the most likely desired input. Try to keep these inputs as pure Python as
possible, as that makes compatibility easier (even through other languages via JSON). Those will be the :obj:`inputs` of
the pipeline (:obj:`preprocess`).
Then define the :obj:`outputs`. Same policy as the :obj:`inputs`. The simpler, the better. Those will be the outputs of
the :obj:`postprocess` method.
Start by inheriting the base class :obj:`Pipeline`, and implement the 4 needed methods: :obj:`preprocess`,
:obj:`_forward`, :obj:`postprocess` and :obj:`_sanitize_parameters`.
.. code-block::

    from transformers import Pipeline


    class MyPipeline(Pipeline):
        def _sanitize_parameters(self, **kwargs):
            preprocess_kwargs = {}
            if "maybe_arg" in kwargs:
                preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
            return preprocess_kwargs, {}, {}

        def preprocess(self, inputs, maybe_arg=2):
            model_input = Tensor(....)
            return {"model_input": model_input}

        def _forward(self, model_inputs):
            # model_inputs == {"model_input": model_input}
            outputs = self.model(**model_inputs)
            # Maybe {"logits": Tensor(...)}
            return outputs

        def postprocess(self, model_outputs):
            best_class = model_outputs["logits"].softmax(-1)
            return best_class
The structure of this breakdown is to support relatively seamless CPU/GPU handling, while also supporting
pre/postprocessing on the CPU on different threads.
:obj:`preprocess` will take the originally defined inputs, and turn them into something feedable to the model. It might
contain more information and is usually a :obj:`Dict`.
:obj:`_forward` is the implementation detail and is not meant to be called directly. :obj:`forward` is the preferred
way to call it, as it contains safeguards to make sure everything is working on the expected device. Anything that is
linked to a real model belongs in the :obj:`_forward` method; anything else goes in preprocess/postprocess.
:obj:`postprocess` will take the output of :obj:`_forward` and turn it into the final output that was decided
earlier.
:obj:`_sanitize_parameters` exists to allow users to pass any parameters whenever they wish, be it at initialization
time ``pipeline(...., maybe_arg=4)`` or at call time ``pipe = pipeline(...); output = pipe(...., maybe_arg=4)``.
The returns of :obj:`_sanitize_parameters` are the 3 dicts of kwargs that will be passed directly to :obj:`preprocess`,
:obj:`_forward` and :obj:`postprocess`. Don't fill anything if the caller didn't call with any extra parameter. That
allows keeping the default arguments in the function definition, which is always more "natural".
A classic example would be a :obj:`top_k` argument in the post processing of classification tasks.
.. code-block::

    >>> pipe = pipeline("my-new-task")
    >>> pipe("This is a test")
    [{"label": "1-star", "score": 0.8}, {"label": "2-star", "score": 0.1}, {"label": "3-star", "score": 0.05},
    {"label": "4-star", "score": 0.025}, {"label": "5-star", "score": 0.025}]

    >>> pipe("This is a test", top_k=2)
    [{"label": "1-star", "score": 0.8}, {"label": "2-star", "score": 0.1}]
In order to achieve that, we'll update our :obj:`postprocess` method with a default parameter of :obj:`5`, and edit
:obj:`_sanitize_parameters` to allow this new parameter.
.. code-block::

    def postprocess(self, model_outputs, top_k=5):
        best_class = model_outputs["logits"].softmax(-1)
        # Add logic to handle top_k
        return best_class

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        if "maybe_arg" in kwargs:
            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]

        postprocess_kwargs = {}
        if "top_k" in kwargs:
            postprocess_kwargs["top_k"] = kwargs["top_k"]
        return preprocess_kwargs, {}, postprocess_kwargs
Try to keep the inputs/outputs very simple and ideally JSON-serializable, as it makes the pipeline very easy to use
without requiring users to understand new kinds of objects. It's also relatively common to support many different types
of arguments for ease of use (audio files, for instance, can be filenames, URLs or raw bytes).
Adding it to the list of supported tasks
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Go to ``src/transformers/pipelines/__init__.py`` and fill in :obj:`SUPPORTED_TASKS` with your newly created pipeline.
If possible, it should provide a default model.
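
For orientation, a minimal sketch of what an entry can look like (the task name, model class and default checkpoint
below are made-up placeholders; mirror the existing entries in the file for the exact schema):

.. code-block::

    # Hypothetical entry; copy the shape of the real entries around it.
    "my-new-task": {
        "impl": MyPipeline,
        "tf": None,
        "pt": AutoModelForSequenceClassification if is_torch_available() else None,
        "default": {"model": {"pt": "username/tiny-random-my-new-task"}},
    },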
Adding tests
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Create a new file ``tests/test_pipelines_MY_PIPELINE.py`` with examples modeled on the other tests.
The :obj:`run_pipeline_test` function will be very generic and run on small random models on every possible
architecture as defined by :obj:`model_mapping` and :obj:`tf_model_mapping`.
This is very important to test future compatibility: if someone adds a new model for
:obj:`XXXForQuestionAnswering`, the pipeline test will attempt to run on it. Because the models are random, it's
impossible to check for actual values, which is why there is a helper :obj:`ANY` that will simply attempt to match the
type of the pipeline output.
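
For instance, an expected output can be written like this (a hypothetical check, assuming the test suite's
:obj:`ANY` helper is in scope):

.. code-block::

    # Only the types are checked, since a random model's actual scores are meaningless.
    self.assertEqual(outputs, [{"label": ANY(str), "score": ANY(float)}])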
You also *need* to implement 2 (ideally 4) tests.
- :obj:`test_small_model_pt` : Define 1 small model for this pipeline (it doesn't matter if the results don't make
  sense) and test the pipeline outputs. The results should be the same as :obj:`test_small_model_tf`.
- :obj:`test_small_model_tf` : Define 1 small model for this pipeline (it doesn't matter if the results don't make
  sense) and test the pipeline outputs. The results should be the same as :obj:`test_small_model_pt`.
- :obj:`test_large_model_pt` (:obj:`optional`): Tests the pipeline on a real model where the results are supposed to
  make sense. These tests are slow and should be marked as such. Here the goal is to showcase the pipeline and to make
  sure there is no drift in future releases.
- :obj:`test_large_model_tf` (:obj:`optional`): Tests the pipeline on a real model where the results are supposed to
  make sense. These tests are slow and should be marked as such. Here the goal is to showcase the pipeline and to make
  sure there is no drift in future releases.
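
As a rough sketch (the task name, checkpoint and output keys below are made-up placeholders; follow the existing
``tests/test_pipelines_*.py`` files for the exact harness), a small model test could look like:

.. code-block::

    from transformers import pipeline
    from transformers.testing_utils import require_torch


    @require_torch
    def test_small_model_pt():
        # A tiny random checkpoint: the outputs don't need to make sense,
        # they only need to stay stable across releases.
        pipe = pipeline("my-new-task", model="username/tiny-random-model", framework="pt")
        outputs = pipe("This is a test")
        assert isinstance(outputs, list)
        assert "label" in outputs[0] and "score" in outputs[0]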


@@ -506,6 +506,7 @@ Flax), PyTorch, and/or TensorFlow.
migration
contributing
add_new_model
add_new_pipeline
fast_tokenizers
performance
parallelism


@@ -46,12 +46,53 @@ The pipeline abstraction
The `pipeline` abstraction is a wrapper around all the other available pipelines. It is instantiated as any other
pipeline but requires an additional argument which is the `task`.
Simple call on one item:
.. code-block::

    >>> pipe = pipeline("text-classification")
    >>> pipe("This restaurant is awesome")
    [{'label': 'POSITIVE', 'score': 0.9998743534088135}]
To call a pipeline on many items, you can call it with a `list`.
.. code-block::

    >>> pipe = pipeline("text-classification")
    >>> pipe(["This restaurant is awesome", "This restaurant is awful"])
    [{'label': 'POSITIVE', 'score': 0.9998743534088135},
    {'label': 'NEGATIVE', 'score': 0.9996669292449951}]
To iterate over full datasets it is recommended to use a :obj:`dataset` directly. This means you don't need to allocate
the whole dataset at once, nor do you need to do batching yourself. This should work just as fast as custom loops on
GPU. If it doesn't, don't hesitate to create an issue.
.. code-block::

    import datasets
    import tqdm
    from transformers import pipeline
    from transformers.pipelines.base import KeyDataset

    pipe = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h", device=0)
    dataset = datasets.load_dataset("superb", name="asr", split="test")

    # KeyDataset (only `pt`) will simply return the item in the dict returned by the dataset item
    # as we're not interested in the `target` part of the dataset.
    for out in tqdm.tqdm(pipe(KeyDataset(dataset, "file"))):
        print(out)
        # {"text": "NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD NIGHT HUSBAND"}
        # {"text": ....}
        # ....
.. autofunction:: transformers.pipeline
Implementing a pipeline
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
:doc:`Implementing a new pipeline <../add_new_pipeline>`
The task specific pipelines
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
AudioClassificationPipeline
=======================================================================================================================


@@ -67,8 +67,8 @@ make them readable. For instance:
>>> classifier('We are very happy to show you the 🤗 Transformers library.')
[{'label': 'POSITIVE', 'score': 0.9998}]
That's encouraging! You can use it on a list of sentences, which will be preprocessed then fed to the model as a
`batch`, returning a list of dictionaries like this one:
That's encouraging! You can use it on a list of sentences, which will be preprocessed then fed to the model, returning
a list of dictionaries like this one:
.. code-block::
@@ -79,6 +79,8 @@ That's encouraging! You can use it on a list of sentences, which will be preproc
label: POSITIVE, with score: 0.9998
label: NEGATIVE, with score: 0.5309
To use with a large dataset, look at :doc:`iterating over a pipeline <./main_classes/pipelines>`.
You can see the second sentence has been classified as negative (it needs to be positive or negative) but its score is
fairly neutral.


@@ -249,18 +249,22 @@ def check_task(task: str) -> Tuple[Dict, Any]:
task (:obj:`str`):
The task defining which pipeline will be returned. Currently accepted tasks are:
- :obj:`"feature-extraction"`
- :obj:`"text-classification"`
- :obj:`"sentiment-analysis"` (alias of :obj:`"text-classification")
- :obj:`"token-classification"`
- :obj:`"ner"` (alias of :obj:`"token-classification")
- :obj:`"question-answering"`
- :obj:`"fill-mask"`
- :obj:`"summarization"`
- :obj:`"translation_xx_to_yy"`
- :obj:`"translation"`
- :obj:`"text-generation"`
- :obj:`"audio-classification"`
- :obj:`"automatic-speech-recognition"`
- :obj:`"conversational"`
- :obj:`"feature-extraction"`
- :obj:`"fill-mask"`
- :obj:`"image-classification"`
- :obj:`"question-answering"`
- :obj:`"table-question-answering"`
- :obj:`"text2text-generation"`
- :obj:`"text-classification"` (alias :obj:`"sentiment-analysis" available)
- :obj:`"text-generation"`
- :obj:`"token-classification"` (alias :obj:`"ner"` available)
- :obj:`"translation"`
- :obj:`"translation_xx_to_yy"`
- :obj:`"summarization"`
- :obj:`"zero-shot-classification"`
Returns:
(task_defaults:obj:`dict`, task_options: (:obj:`tuple`, None)) The actual dictionary required to initialize the
@@ -312,21 +316,26 @@ def pipeline(
task (:obj:`str`):
The task defining which pipeline will be returned. Currently accepted tasks are:
- :obj:`"feature-extraction"`: will return a :class:`~transformers.FeatureExtractionPipeline`.
- :obj:`"text-classification"`: will return a :class:`~transformers.TextClassificationPipeline`.
- :obj:`"sentiment-analysis"`: (alias of :obj:`"text-classification"`) will return a
:class:`~transformers.TextClassificationPipeline`.
- :obj:`"token-classification"`: will return a :class:`~transformers.TokenClassificationPipeline`.
- :obj:`"ner"` (alias of :obj:`"token-classification"`): will return a
:class:`~transformers.TokenClassificationPipeline`.
- :obj:`"question-answering"`: will return a :class:`~transformers.QuestionAnsweringPipeline`.
- :obj:`"fill-mask"`: will return a :class:`~transformers.FillMaskPipeline`.
- :obj:`"summarization"`: will return a :class:`~transformers.SummarizationPipeline`.
- :obj:`"translation_xx_to_yy"`: will return a :class:`~transformers.TranslationPipeline`.
- :obj:`"text2text-generation"`: will return a :class:`~transformers.Text2TextGenerationPipeline`.
- :obj:`"text-generation"`: will return a :class:`~transformers.TextGenerationPipeline`.
- :obj:`"zero-shot-classification:`: will return a :class:`~transformers.ZeroShotClassificationPipeline`.
- :obj:`"conversational"`: will return a :class:`~transformers.ConversationalPipeline`.
- :obj:`"audio-classification"`: will return a :class:`~transformers.AudioClassificationPipeline`:.
- :obj:`"automatic-speech-recognition"`: will return a
:class:`~transformers.AutomaticSpeechRecognitionPipeline`:.
- :obj:`"conversational"`: will return a :class:`~transformers.ConversationalPipeline`:.
- :obj:`"feature-extraction"`: will return a :class:`~transformers.FeatureExtractionPipeline`:.
- :obj:`"fill-mask"`: will return a :class:`~transformers.FillMaskPipeline`:.
- :obj:`"image-classification"`: will return a :class:`~transformers.ImageClassificationPipeline`:.
- :obj:`"question-answering"`: will return a :class:`~transformers.QuestionAnsweringPipeline`:.
- :obj:`"table-question-answering"`: will return a :class:`~transformers.TableQuestionAnsweringPipeline`:.
- :obj:`"text2text-generation"`: will return a :class:`~transformers.Text2TextGenerationPipeline`:.
- :obj:`"text-classification"` (alias :obj:`"sentiment-analysis" available): will return a
:class:`~transformers.TextClassificationPipeline`:.
- :obj:`"text-generation"`: will return a :class:`~transformers.TextGenerationPipeline`:.
- :obj:`"token-classification"` (alias :obj:`"ner"` available): will return a
:class:`~transformers.TokenClassificationPipeline`:.
- :obj:`"translation"`: will return a :class:`~transformers.TranslationPipeline`:.
- :obj:`"translation_xx_to_yy"`: will return a :class:`~transformers.TranslationPipeline`:.
- :obj:`"summarization"`: will return a :class:`~transformers.SummarizationPipeline`:.
- :obj:`"zero-shot-classification"`: will return a :class:`~transformers.ZeroShotClassificationPipeline`:.
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`):
The model that will be used by the pipeline to make predictions. This can be a model identifier or an
actual instance of a pretrained model inheriting from :class:`~transformers.PreTrainedModel` (for PyTorch)


@@ -12,23 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
from typing import TYPE_CHECKING, Optional, Union
from typing import Union
import numpy as np
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import add_end_docstrings, is_torch_available
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline
if TYPE_CHECKING:
from ..modeling_tf_utils import TFPreTrainedModel
from ..modeling_utils import PreTrainedModel
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
logger = logging.get_logger(__name__)
@@ -84,14 +77,10 @@ class AudioClassificationPipeline(Pipeline):
<https://huggingface.co/models?filter=audio-classification>`__.
"""
def __init__(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel"],
feature_extractor: PreTrainedFeatureExtractor,
framework: Optional[str] = None,
**kwargs
):
super().__init__(model, feature_extractor=feature_extractor, framework=framework, **kwargs)
def __init__(self, *args, **kwargs):
# Default, might be overridden by the model.config.
kwargs["top_k"] = 5
super().__init__(*args, **kwargs)
if self.framework != "pt":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
@@ -101,7 +90,6 @@ class AudioClassificationPipeline(Pipeline):
def __call__(
self,
inputs: Union[np.ndarray, bytes, str],
top_k: Optional[int] = None,
**kwargs,
):
"""
@@ -126,6 +114,18 @@ class AudioClassificationPipeline(Pipeline):
- **label** (:obj:`str`) -- The label predicted.
- **score** (:obj:`float`) -- The corresponding probability.
"""
return super().__call__(inputs, **kwargs)
def _sanitize_parameters(self, top_k=None, **kwargs):
# `top_k` is the only parameter, and it belongs to postprocessing.
postprocess_params = {}
if top_k is not None:
if top_k > self.model.config.num_labels:
top_k = self.model.config.num_labels
postprocess_params["top_k"] = top_k
return {}, {}, postprocess_params
def preprocess(self, inputs):
if isinstance(inputs, str):
with open(inputs, "rb") as f:
inputs = f.read()
@@ -136,24 +136,23 @@ class AudioClassificationPipeline(Pipeline):
if not isinstance(inputs, np.ndarray):
raise ValueError("We expect a numpy ndarray as input")
if len(inputs.shape) != 1:
raise ValueError("We expect a single channel audio input for AudioClassificationPipeline")
if top_k is None or top_k > self.model.config.num_labels:
top_k = self.model.config.num_labels
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
processed = self.ensure_tensor_on_device(**processed)
return processed
with torch.no_grad():
outputs = self.model(**processed)
def _forward(self, model_inputs):
model_outputs = self.model(**model_inputs)
return model_outputs
probs = outputs.logits[0].softmax(-1)
scores, ids = probs.topk(top_k)
def postprocess(self, model_outputs, top_k=5):
probs = model_outputs.logits[0].softmax(-1)
scores, ids = probs.topk(top_k)
scores = scores.tolist()
ids = ids.tolist()
scores = scores.tolist()
ids = ids.tolist()
labels = [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]


@@ -124,6 +124,13 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
- **text** (:obj:`str`) -- The recognized text.
"""
return super().__call__(inputs, **kwargs)
def _sanitize_parameters(self, **kwargs):
# No parameters on this pipeline right now
return {}, {}, {}
def preprocess(self, inputs):
if isinstance(inputs, str):
with open(inputs, "rb") as f:
inputs = f.read()
@@ -131,27 +138,34 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
if isinstance(inputs, bytes):
inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
assert isinstance(inputs, np.ndarray), "We expect a numpy ndarray as input"
assert len(inputs.shape) == 1, "We expect a single channel audio input for AutomaticSpeechRecognitionPipeline"
if not isinstance(inputs, np.ndarray):
raise ValueError("We expect a numpy ndarray as input")
if len(inputs.shape) != 1:
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
processed = self.ensure_tensor_on_device(**processed)
return processed
def _forward(self, model_inputs):
name = self.model.__class__.__name__
if name.endswith("ForConditionalGeneration") or name.endswith("EncoderDecoderModel"):
encoder = self.model.get_encoder()
# we need to pass `model_inputs.get("attention_mask")` here since audio encoder
# attention mask length is different from expected text decoder `encoder_attention_mask` length
# `generate` magic to create the mask automatically won't work, we basically need to help
# it here.
tokens = self.model.generate(
encoder_outputs=encoder(**processed), attention_mask=processed.get("attention_mask")
encoder_outputs=encoder(**model_inputs), attention_mask=model_inputs.get("attention_mask")
)
tokens = tokens.squeeze(0)
elif name.endswith("ForCTC"):
outputs = self.model(**processed)
outputs = self.model(**model_inputs)
tokens = outputs.logits.squeeze(0).argmax(dim=-1)
return tokens
def postprocess(self, model_outputs):
skip_special_tokens = False if "CTC" in self.tokenizer.__class__.__name__ else True
recognized_string = self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
recognized_string = self.tokenizer.decode(model_outputs, skip_special_tokens=skip_special_tokens)
return {"text": recognized_string}


@@ -20,18 +20,21 @@ import pickle
import sys
import warnings
from abc import ABC, abstractmethod
from collections import UserDict
from contextlib import contextmanager
from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
from ..file_utils import ModelOutput, add_end_docstrings, is_tf_available, is_torch_available
from ..modelcard import ModelCard
from ..models.auto.configuration_auto import AutoConfig
from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import logging
GenericTensor = Union[List["GenericTensor"], "torch.Tensor", "tf.Tensor"]
if is_tf_available():
import tensorflow as tf
@@ -39,8 +42,12 @@ if is_tf_available():
if is_torch_available():
import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset
from ..models.auto.modeling_auto import AutoModel
else:
Dataset = None
KeyDataset = None
if TYPE_CHECKING:
from ..modeling_tf_utils import TFPreTrainedModel
@@ -50,6 +57,12 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def collate_fn(items):
if len(items) != 1:
raise ValueError("This collate_fn is meant to be used with batch_size=1")
return items[0]
def infer_framework_load_model(
model,
config: AutoConfig,
@@ -585,6 +598,51 @@ PIPELINE_INIT_ARGS = r"""
Flag indicating if the output of the pipeline should happen in a binary format (i.e., pickle) or as raw text.
"""
if is_torch_available():
class PipelineDataset(Dataset):
def __init__(self, dataset, process, params):
self.dataset = dataset
self.process = process
self.params = params
def __len__(self):
return len(self.dataset)
def __getitem__(self, i):
item = self.dataset[i]
processed = self.process(item, **self.params)
return processed
class PipelineIterator(IterableDataset):
def __init__(self, loader, infer, params):
self.loader = loader
self.infer = infer
self.params = params
def __len__(self):
return len(self.loader)
def __iter__(self):
self.iterator = iter(self.loader)
return self
def __next__(self):
item = next(self.iterator)
processed = self.infer(item, **self.params)
return processed
class KeyDataset(Dataset):
def __init__(self, dataset: Dataset, key: str):
self.dataset = dataset
self.key = key
def __len__(self):
return len(self.dataset)
def __getitem__(self, i):
return self.dataset[i][self.key]
@add_end_docstrings(PIPELINE_INIT_ARGS)
class Pipeline(_ScikitCompat):
@@ -618,6 +676,7 @@ class Pipeline(_ScikitCompat):
args_parser: ArgumentHandler = None,
device: int = -1,
binary_output: bool = False,
**kwargs,
):
if framework is None:
@@ -641,6 +700,9 @@ class Pipeline(_ScikitCompat):
if task_specific_params is not None and task in task_specific_params:
self.model.config.update(task_specific_params.get(task))
self.call_count = 0
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
def save_pretrained(self, save_directory: str):
"""
Save the pipeline's model and tokenizer.
@@ -707,15 +769,31 @@ class Pipeline(_ScikitCompat):
Ensure PyTorch tensors are on the specified device.
Args:
inputs (keyword arguments that should be :obj:`torch.Tensor`): The tensors to place on :obj:`self.device`.
inputs (keyword arguments that should be :obj:`torch.Tensor`, the rest is ignored): The tensors to place on :obj:`self.device`.
Recursive on lists, tuples, dicts, :obj:`UserDict` and :obj:`ModelOutput` objects.
Return:
:obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
"""
return {
name: tensor.to(self.device) if isinstance(tensor, torch.Tensor) else tensor
for name, tensor in inputs.items()
}
return self._ensure_tensor_on_device(inputs, self.device)
def _ensure_tensor_on_device(self, inputs, device):
if isinstance(inputs, ModelOutput):
return ModelOutput(
{name: self._ensure_tensor_on_device(tensor, device) for name, tensor in inputs.items()}
)
elif isinstance(inputs, dict):
return {name: self._ensure_tensor_on_device(tensor, device) for name, tensor in inputs.items()}
elif isinstance(inputs, UserDict):
return UserDict({name: self._ensure_tensor_on_device(tensor, device) for name, tensor in inputs.items()})
elif isinstance(inputs, list):
return [self._ensure_tensor_on_device(item, device) for item in inputs]
elif isinstance(inputs, tuple):
return tuple([self._ensure_tensor_on_device(item, device) for item in inputs])
elif isinstance(inputs, torch.Tensor):
return inputs.to(device)
else:
return inputs
def check_model_type(self, supported_models: Union[List[str], dict]):
"""
@@ -739,65 +817,108 @@ class Pipeline(_ScikitCompat):
f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}."
)
def _parse_and_tokenize(
self, inputs, padding=True, add_special_tokens=True, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs
):
@abstractmethod
def _sanitize_parameters(self, **pipeline_parameters):
"""
Parse arguments and tokenize
_sanitize_parameters will be called with any extra named arguments from either the `__init__` or `__call__`
methods. It should return 3 dictionaries of the resolved parameters used by the various `preprocess`,
`forward` and `postprocess` methods. Do not fill the dictionaries if the caller didn't specify any kwargs. This
lets you keep defaults in function signatures, which is more "natural".
It is not meant to be called directly; it will be called automatically and the final parameters resolved by
`__init__` and `__call__`.
"""
# Parse arguments
if getattr(self.tokenizer, "pad_token", None) is None:
padding = False
inputs = self.tokenizer(
inputs,
add_special_tokens=add_special_tokens,
return_tensors=self.framework,
padding=padding,
truncation=truncation,
)
return inputs
raise NotImplementedError("_sanitize_parameters not implemented")
def __call__(self, inputs, *args, **kwargs):
try:
model_inputs = self._parse_and_tokenize(inputs, *args, **kwargs)
outputs = self._forward(model_inputs)
return outputs
except ValueError:
# XXX: Some tokenizers do NOT have a pad token, hence we cannot run the inference
# in a batch, instead we run everything sequentially
if isinstance(inputs, list):
values = []
for input_ in inputs:
model_input = self._parse_and_tokenize(input_, padding=False, *args, **kwargs)
value = self._forward(model_input)
values.append(value.squeeze(0))
else:
model_input = self._parse_and_tokenize(inputs, padding=False, *args, **kwargs)
values = self._forward(model_input)
return values
def _forward(self, inputs, return_tensors=False):
@abstractmethod
def preprocess(self, input_: Any, **preprocess_parameters: Dict) -> Dict[str, GenericTensor]:
"""
Internal framework specific forward dispatching
Args:
inputs: dict holding all the keyword arguments required by the model forward method.
return_tensors: Whether to return native framework (pt/tf) tensors rather than numpy array
Returns:
Numpy array
Preprocess will take the `input_` of a specific pipeline and return a dictionary of everything necessary for
`_forward` to run properly. It should contain at least one tensor, but might have arbitrary other items.
"""
# Encode for forward
raise NotImplementedError("preprocess not implemented")
@abstractmethod
def _forward(self, input_tensors: Dict[str, GenericTensor], **forward_parameters: Dict) -> ModelOutput:
"""
_forward will receive the prepared dictionary from `preprocess` and run it on the model. This method might
involve the GPU or the CPU and should be agnostic to it. Isolating this function is the reason for `preprocess`
and `postprocess` to exist, so that the hot path, this method, can generally run as fast as possible.
It is not meant to be called directly, `forward` is preferred. It is basically the same but contains additional
code surrounding `_forward` making sure tensors and models are on the same device, disabling the training part
of the code (leading to faster inference).
"""
raise NotImplementedError("_forward not implemented")
@abstractmethod
def postprocess(self, model_outputs: ModelOutput, **postprocess_parameters: Dict) -> Any:
"""
Postprocess will receive the raw outputs of the `_forward` method, generally tensors, and reformat them into
something more friendly. Generally it will output a list or a dict of results (containing just strings and
numbers).
"""
raise NotImplementedError("postprocess not implemented")
def forward(self, model_inputs, **forward_params):
with self.device_placement():
if self.framework == "tf":
# TODO trace model
predictions = self.model(inputs.data, training=False)[0]
else:
model_inputs["training"] = False
model_outputs = self._forward(model_inputs, **forward_params)
elif self.framework == "pt":
with torch.no_grad():
inputs = self.ensure_tensor_on_device(**inputs)
predictions = self.model(**inputs)[0].cpu()
model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
model_outputs = self._forward(model_inputs, **forward_params)
model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
else:
raise ValueError(f"Framework {self.framework} is not supported")
return model_outputs
if return_tensors:
return predictions
def get_iterator(self, inputs, num_workers: int, preprocess_params, forward_params, postprocess_params):
if "TOKENIZERS_PARALLELISM" not in os.environ:
logger.info("Disabling tokenizer parallelism, we're using DataLoader multithreading already")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
dataset = PipelineDataset(inputs, self.preprocess, preprocess_params)
dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, collate_fn=collate_fn)
model_iterator = PipelineIterator(dataloader, self.forward, forward_params)
final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
return final_iterator
def __call__(self, inputs, *args, num_workers=8, **kwargs):
if args:
logger.warning(f"Ignoring args : {args}")
preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(**kwargs)
# Fuse __init__ params and __call__ params without modifying the __init__ ones.
preprocess_params = {**self._preprocess_params, **preprocess_params}
forward_params = {**self._forward_params, **forward_params}
postprocess_params = {**self._postprocess_params, **postprocess_params}
self.call_count += 1
if self.call_count > 10 and self.framework == "pt" and self.device.type == "cuda":
warnings.warn(
"You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset",
UserWarning,
)
if isinstance(inputs, list):
if self.framework == "pt":
final_iterator = self.get_iterator(
inputs, num_workers, preprocess_params, forward_params, postprocess_params
)
outputs = [output for output in final_iterator]
return outputs
else:
return self.run_multi(inputs, preprocess_params, forward_params, postprocess_params)
elif Dataset is not None and isinstance(inputs, Dataset):
return self.get_iterator(inputs, num_workers, preprocess_params, forward_params, postprocess_params)
else:
return predictions.numpy()
return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
def run_multi(self, inputs, preprocess_params, forward_params, postprocess_params):
return [self.run_single(item, preprocess_params, forward_params, postprocess_params) for item in inputs]
def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
model_inputs = self.preprocess(inputs, **preprocess_params)
model_outputs = self.forward(model_inputs, **forward_params)
outputs = self.postprocess(model_outputs, **postprocess_params)
return outputs


@@ -190,23 +190,34 @@ class ConversationalPipeline(Pipeline):
conversational_pipeline([conversation_1, conversation_2])
"""
def __init__(self, min_length_for_response=32, minimum_tokens=10, *args, **kwargs):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# We need at least an eos_token
# assert self.tokenizer.eos_token_id is not None, "ConversationalPipeline tokenizer should have an EOS token set"
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.min_length_for_response = min_length_for_response
self.minimum_tokens = minimum_tokens
def __call__(
self,
conversations: Union[Conversation, List[Conversation]],
clean_up_tokenization_spaces=True,
**generate_kwargs
def _sanitize_parameters(
self, min_length_for_response=None, minimum_tokens=None, clean_up_tokenization_spaces=None, **generate_kwargs
):
preprocess_params = {}
forward_params = {}
postprocess_params = {}
if min_length_for_response is not None:
preprocess_params["min_length_for_response"] = min_length_for_response
if minimum_tokens is not None:
forward_params["minimum_tokens"] = minimum_tokens
if "max_length" in generate_kwargs:
forward_params["max_length"] = generate_kwargs["max_length"]
# self.max_length = generate_kwargs.get("max_length", self.model.config.max_length)
if clean_up_tokenization_spaces is not None:
postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces
if generate_kwargs:
forward_params.update(generate_kwargs)
return preprocess_params, forward_params, postprocess_params
def __call__(self, conversations: Union[Conversation, List[Conversation]], num_workers=0, **kwargs):
r"""
Generate responses for the conversation(s) given as inputs.
@@ -223,117 +234,67 @@ class ConversationalPipeline(Pipeline):
:class:`~transformers.Conversation` or a list of :class:`~transformers.Conversation`: Conversation(s) with
updated generated responses for those containing a new user input.
"""
if isinstance(conversations, Conversation):
conversations = [conversations]
# Input validation
if isinstance(conversations, list):
for conversation in conversations:
assert isinstance(
conversation, Conversation
), "ConversationalPipeline expects a Conversation or list of Conversations as an input"
if conversation.new_user_input is None:
raise ValueError(
f"Conversation with UUID {type(conversation.uuid)} does not contain new user input to process. "
"Add user inputs with the conversation's `add_user_input` method"
)
assert (
self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None
), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input"
else:
raise ValueError("ConversationalPipeline expects a Conversation or list of Conversations as an input")
with self.device_placement():
inputs = self._parse_and_tokenize(conversations)
if self.framework == "pt":
inputs = self.ensure_tensor_on_device(**inputs)
input_length = inputs["input_ids"].shape[-1]
elif self.framework == "tf":
input_length = tf.shape(inputs["input_ids"])[-1].numpy()
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
n = inputs["input_ids"].shape[1]
if max_length - self.minimum_tokens < n:
logger.warning(
f"Conversation input is to long ({n}), trimming it to ({max_length} - {self.minimum_tokens})"
)
trim = max_length - self.minimum_tokens
inputs["input_ids"] = inputs["input_ids"][:, -trim:]
inputs["attention_mask"] = inputs["attention_mask"][:, -trim:]
generated_responses = self.model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
**generate_kwargs,
)
if self.model.config.is_encoder_decoder:
if self.framework == "pt":
history = torch.cat((inputs["input_ids"], generated_responses[:, 1:]), 1)
elif self.framework == "tf":
history = tf.concat([inputs["input_ids"], generated_responses[:, 1:]], 1)
else:
history = generated_responses
history = self._clean_padding_history(history)
if self.model.config.is_encoder_decoder:
start_position = 1
else:
start_position = input_length
output = []
for conversation_index, conversation in enumerate(conversations):
conversation.mark_processed()
conversation.generated_responses.append(
self.tokenizer.decode(
generated_responses[conversation_index][start_position:],
skip_special_tokens=True,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
)
output.append(conversation)
if len(output) == 1:
return output[0]
else:
return output
def _clean_padding_history(self, generated_tensor) -> List[List[int]]:
"""
Cleans the padding history. Padding may be generated in two places when multiple conversations are provided as
an input:
- at the end of the concatenated history and new user input, so that all input to the model have the same
length
- at the end of the generated response, as some responses will be longer than others
This method cleans up these padding tokens so that the history for each conversation is not impacted by the
batching process.
"""
outputs = []
for sequence in generated_tensor:
sequence_tokens = []
is_previous_pad = False
for token in sequence:
if token == self.tokenizer.pad_token_id:
if self.tokenizer.pad_token_id != self.tokenizer.eos_token_id:
continue
if is_previous_pad:
continue
else:
is_previous_pad = True
else:
is_previous_pad = False
if self.framework == "pt":
sequence_tokens.append(token.item())
else:
sequence_tokens.append(int(token.numpy()))
outputs.append(sequence_tokens)
# XXX: num_workers==0 is required to be backward compatible
# Otherwise the threads will require a Conversation copy.
# This will definitely hinder performance on GPU, but has to be opted
# in because of this BC change.
outputs = super().__call__(conversations, num_workers=num_workers, **kwargs)
if isinstance(outputs, list) and len(outputs) == 1:
return outputs[0]
return outputs
def _legacy_parse_and_tokenize(self, conversation: List[Conversation]) -> List[int]:
def preprocess(self, conversation: Conversation) -> Dict[str, Any]:
if not isinstance(conversation, Conversation):
raise ValueError("ConversationalPipeline, expects Conversation as inputs")
if conversation.new_user_input is None:
raise ValueError(
f"Conversation with UUID {type(conversation.uuid)} does not contain new user input to process. "
"Add user inputs with the conversation's `add_user_input` method"
)
if hasattr(self.tokenizer, "_build_conversation_input_ids"):
input_ids = self.tokenizer._build_conversation_input_ids(conversation)
else:
# If the tokenizer cannot handle conversations, we default to only the old version
input_ids = self._legacy_parse_and_tokenize(conversation)
if self.framework == "pt":
input_ids = torch.LongTensor([input_ids])
elif self.framework == "tf":
input_ids = tf.constant([input_ids])
return {"input_ids": input_ids, "conversation": conversation}
def _forward(self, model_inputs, minimum_tokens=10, **generate_kwargs):
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
n = model_inputs["input_ids"].shape[1]
if max_length - minimum_tokens < n:
logger.warning(f"Conversation input is to long ({n}), trimming it to ({max_length} - {minimum_tokens})")
trim = max_length - minimum_tokens
model_inputs["input_ids"] = model_inputs["input_ids"][:, -trim:]
if "attention_mask" in model_inputs:
model_inputs["attention_mask"] = model_inputs["attention_mask"][:, -trim:]
conversation = model_inputs.pop("conversation")
model_inputs["max_length"] = max_length
output_ids = self.model.generate(**model_inputs, **generate_kwargs)
if self.model.config.is_encoder_decoder:
start_position = 1
else:
start_position = n
return {"output_ids": output_ids[0, start_position:], "conversation": conversation}
def postprocess(self, model_outputs, clean_up_tokenization_spaces=True):
output_ids = model_outputs["output_ids"]
answer = self.tokenizer.decode(
output_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
conversation = model_outputs["conversation"]
conversation.mark_processed()
conversation.append_response(answer)
return conversation
def _legacy_parse_and_tokenize(self, conversation: Conversation) -> Dict:
eos_token_id = self.tokenizer.eos_token_id
input_ids = []
for is_user, text in conversation.iter_texts():
@@ -345,14 +306,3 @@ class ConversationalPipeline(Pipeline):
if len(input_ids) > self.tokenizer.model_max_length:
input_ids = input_ids[-self.tokenizer.model_max_length :]
return input_ids
def _parse_and_tokenize(self, conversations: List[Conversation]) -> Dict[str, Any]:
if hasattr(self.tokenizer, "_build_conversation_input_ids"):
input_ids = [self.tokenizer._build_conversation_input_ids(conversation) for conversation in conversations]
else:
# If the tokenizer cannot handle conversations, we default to only the old version
input_ids = [self._legacy_parse_and_tokenize(conversation) for conversation in conversations]
inputs = self.tokenizer.pad(
{"input_ids": input_ids}, padding="longest", return_attention_mask=True, return_tensors=self.framework
)
return inputs


@@ -1,14 +1,6 @@
from typing import TYPE_CHECKING, Optional, Union
from typing import Dict
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..modelcard import ModelCard
from ..tokenization_utils import PreTrainedTokenizer
from .base import ArgumentHandler, Pipeline
if TYPE_CHECKING:
from ..modeling_tf_utils import TFPreTrainedModel
from ..modeling_utils import PreTrainedModel
from .base import GenericTensor, Pipeline
# Can't use @add_end_docstrings(PIPELINE_INIT_ARGS) here because this one does not accept `binary_output`
@@ -49,28 +41,24 @@ class FeatureExtractionPipeline(Pipeline):
the associated CUDA device id.
"""
def __init__(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel"],
tokenizer: PreTrainedTokenizer,
feature_extractor: Optional[PreTrainedFeatureExtractor] = None,
modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None,
args_parser: ArgumentHandler = None,
device: int = -1,
task: str = "",
):
super().__init__(
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
modelcard=modelcard,
framework=framework,
args_parser=args_parser,
device=device,
binary_output=True,
task=task,
)
def _sanitize_parameters(self, **kwargs):
return {}, {}, {}
def preprocess(self, inputs) -> Dict[str, GenericTensor]:
return_tensors = self.framework
model_inputs = self.tokenizer(inputs, return_tensors=return_tensors)
return model_inputs
def _forward(self, model_inputs):
model_outputs = self.model(**model_inputs)
return model_outputs
def postprocess(self, model_outputs):
# [0] is the first available tensor, logits or last_hidden_state.
if self.framework == "pt":
return model_outputs[0].tolist()
elif self.framework == "tf":
return model_outputs[0].numpy().tolist()
def __call__(self, *args, **kwargs):
"""
@@ -82,10 +70,4 @@ class FeatureExtractionPipeline(Pipeline):
Return:
A nested list of :obj:`float`: The features computed by the model.
"""
results = super().__call__(*args, **kwargs)
if isinstance(results, list):
# Sequential run
results = [r.tolist() for r in results]
else:
results = results.tolist()
return results
return super().__call__(*args, **kwargs)


@@ -1,30 +1,19 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from typing import Dict
import numpy as np
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
from ..modelcard import ModelCard
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline, PipelineException
from .base import PIPELINE_INIT_ARGS, GenericTensor, Pipeline, PipelineException
GenericTensor = Union[List["GenericTensor"], "torch.Tensor", "tf.Tensor"]
if TYPE_CHECKING:
from ..modeling_tf_utils import TFPreTrainedModel
from ..modeling_utils import PreTrainedModel
if is_tf_available():
import tensorflow as tf
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_MASKED_LM_MAPPING
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_MASKED_LM_MAPPING
logger = logging.get_logger(__name__)
@@ -58,39 +47,6 @@ class FillMaskPipeline(Pipeline):
This pipeline only works for inputs with exactly one token masked.
"""
def __init__(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel"],
tokenizer: PreTrainedTokenizer,
modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None,
args_parser: ArgumentHandler = None,
device: int = -1,
top_k=5,
targets=None,
task: str = "",
):
super().__init__(
model=model,
tokenizer=tokenizer,
modelcard=modelcard,
framework=framework,
args_parser=args_parser,
device=device,
binary_output=True,
task=task,
)
self.check_model_type(
TF_MODEL_FOR_MASKED_LM_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING
)
self.top_k = top_k
self.targets = targets
if self.tokenizer.mask_token_id is None:
raise PipelineException(
"fill-mask", self.model.base_model_prefix, "The tokenizer does not define a `mask_token`."
)
def get_masked_index(self, input_ids: GenericTensor) -> np.ndarray:
if self.framework == "tf":
masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
@@ -124,17 +80,124 @@ class FillMaskPipeline(Pipeline):
for input_ids in model_inputs["input_ids"]:
self._ensure_exactly_one_mask_token(input_ids)
def get_model_inputs(self, inputs, *args, **kwargs) -> Dict:
if isinstance(inputs, list) and self.tokenizer.pad_token is None:
model_inputs = []
for input_ in inputs:
model_input = self._parse_and_tokenize(input_, padding=False, *args, **kwargs)
model_inputs.append(model_input)
else:
model_inputs = self._parse_and_tokenize(inputs, *args, **kwargs)
def preprocess(self, inputs, return_tensors=None, **preprocess_parameters) -> Dict[str, GenericTensor]:
if return_tensors is None:
return_tensors = self.framework
model_inputs = self.tokenizer(inputs, return_tensors=return_tensors)
self.ensure_exactly_one_mask_token(model_inputs)
return model_inputs
def __call__(self, inputs, *args, targets=None, top_k: Optional[int] = None, **kwargs):
def _forward(self, model_inputs):
model_outputs = self.model(**model_inputs)
model_outputs["input_ids"] = model_inputs["input_ids"][0]
return model_outputs
def postprocess(self, model_outputs, top_k=5, target_ids=None):
# Cap top_k if there are targets
if target_ids is not None and target_ids.shape[0] < top_k:
top_k = target_ids.shape[0]
input_ids = model_outputs["input_ids"]
outputs = model_outputs["logits"]
result = []
if self.framework == "tf":
masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
# Fill mask pipeline supports only one ${mask_token} per sample
logits = outputs[0, masked_index.item(), :]
probs = tf.nn.softmax(logits)
if target_ids is not None:
probs = tf.gather_nd(probs, tf.reshape(target_ids, (-1, 1)))
topk = tf.math.top_k(probs, k=top_k)
values, predictions = topk.values.numpy(), topk.indices.numpy()
else:
masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
# Fill mask pipeline supports only one ${mask_token} per sample
logits = outputs[0, masked_index.item(), :]
probs = logits.softmax(dim=0)
if target_ids is not None:
probs = probs[..., target_ids]
values, predictions = probs.topk(top_k)
for v, p in zip(values.tolist(), predictions.tolist()):
tokens = input_ids.numpy()
if target_ids is not None:
p = target_ids[p].tolist()
tokens[masked_index] = p
# Filter padding out:
tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
result.append(
{
"sequence": self.tokenizer.decode(tokens, skip_special_tokens=True),
"score": v,
"token": p,
"token_str": self.tokenizer.decode(p),
}
)
return result
def get_target_ids(self, targets, top_k=None):
if isinstance(targets, str):
targets = [targets]
try:
vocab = self.tokenizer.get_vocab()
except Exception:
vocab = {}
target_ids = []
for target in targets:
id_ = vocab.get(target, None)
if id_ is None:
input_ids = self.tokenizer(
target,
add_special_tokens=False,
return_attention_mask=False,
return_token_type_ids=False,
max_length=1,
truncation=True,
)["input_ids"]
if len(input_ids) == 0:
logger.warning(
f"The specified target token `{target}` does not exist in the model vocabulary. "
f"We cannot replace it with anything meaningful, ignoring it"
)
continue
id_ = input_ids[0]
# XXX: If users encounter this path,
# it becomes pretty slow, so let's make sure
# the warning enables them to fix the input to
# get faster performance.
logger.warning(
f"The specified target token `{target}` does not exist in the model vocabulary. "
f"Replacing with `{self.tokenizer.convert_ids_to_tokens(id_)}`."
)
target_ids.append(id_)
target_ids = list(set(target_ids))
if len(target_ids) == 0:
raise ValueError("At least one target must be provided when passed.")
target_ids = np.array(target_ids)
return target_ids
def _sanitize_parameters(self, top_k=None, targets=None):
postprocess_params = {}
if targets is not None:
target_ids = self.get_target_ids(targets, top_k)
postprocess_params["target_ids"] = target_ids
if top_k is not None:
postprocess_params["top_k"] = top_k
if self.tokenizer.mask_token_id is None:
raise PipelineException(
"fill-mask", self.model.base_model_prefix, "The tokenizer does not define a `mask_token`."
)
return {}, {}, postprocess_params
def __call__(self, inputs, *args, **kwargs):
"""
Fill the masked token in the text(s) given as inputs.
@@ -156,126 +219,4 @@ class FillMaskPipeline(Pipeline):
- **token** (:obj:`int`) -- The predicted token id (to replace the masked one).
- **token** (:obj:`str`) -- The predicted token (to replace the masked one).
"""
model_inputs = self.get_model_inputs(inputs, *args, **kwargs)
self.ensure_exactly_one_mask_token(model_inputs)
if isinstance(model_inputs, list):
outputs = []
for model_input in model_inputs:
output = self._forward(model_input, return_tensors=True)
outputs.append(output)
batch_size = len(model_inputs)
else:
outputs = self._forward(model_inputs, return_tensors=True)
batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0)
# top_k must be defined
if top_k is None:
top_k = self.top_k
results = []
if targets is None and self.targets is not None:
targets = self.targets
if targets is not None:
if isinstance(targets, str):
targets = [targets]
try:
vocab = self.tokenizer.get_vocab()
except Exception:
vocab = {}
target_ids = []
for target in targets:
id_ = vocab.get(target, None)
if id_ is None:
input_ids = self.tokenizer(
target,
add_special_tokens=False,
return_attention_mask=False,
return_token_type_ids=False,
max_length=1,
truncation=True,
)["input_ids"]
if len(input_ids) == 0:
logger.warning(
f"The specified target token `{target}` does not exist in the model vocabulary. "
f"We cannot replace it with anything meaningful, ignoring it"
)
continue
id_ = input_ids[0]
# XXX: If users encounter this path,
# it becomes pretty slow, so let's make sure
# the warning enables them to fix the input to
# get faster performance.
logger.warning(
f"The specified target token `{target}` does not exist in the model vocabulary. "
f"Replacing with `{self.tokenizer.convert_ids_to_tokens(id_)}`."
)
target_ids.append(id_)
target_ids = list(set(target_ids))
if len(target_ids) == 0:
raise ValueError("At least one target must be provided when passed.")
target_ids = np.array(target_ids)
# Cap top_k if there are targets
if top_k > target_ids.shape[0]:
top_k = target_ids.shape[0]
for i in range(batch_size):
if isinstance(model_inputs, list):
input_ids = model_inputs[i]["input_ids"][0]
else:
input_ids = model_inputs["input_ids"][i]
result = []
if self.framework == "tf":
masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
# Fill mask pipeline supports only one ${mask_token} per sample
if isinstance(outputs, list):
logits = outputs[i][0, masked_index.item(), :]
else:
logits = outputs[i, masked_index.item(), :]
probs = tf.nn.softmax(logits)
if targets is not None:
probs = tf.gather_nd(probs, tf.reshape(target_ids, (-1, 1)))
topk = tf.math.top_k(probs, k=top_k)
values, predictions = topk.values.numpy(), topk.indices.numpy()
else:
masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
# Fill mask pipeline supports only one ${mask_token} per sample
if isinstance(outputs, list):
logits = outputs[i][0, masked_index.item(), :]
else:
logits = outputs[i, masked_index.item(), :]
probs = logits.softmax(dim=0)
if targets is not None:
probs = probs[..., target_ids]
values, predictions = probs.topk(top_k)
for v, p in zip(values.tolist(), predictions.tolist()):
tokens = input_ids.numpy()
if targets is not None:
p = target_ids[p].tolist()
tokens[masked_index] = p
# Filter padding out:
tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
result.append(
{
"sequence": self.tokenizer.decode(tokens, skip_special_tokens=True),
"score": v,
"token": p,
"token_str": self.tokenizer.decode(p),
}
)
# Append
results += [result]
if len(results) == 1:
return results[0]
return results
return super().__call__(inputs, **kwargs)


@@ -1,24 +1,17 @@
import os
from typing import TYPE_CHECKING, List, Optional, Union
from typing import List, Union
import requests
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline
if TYPE_CHECKING:
from ..modeling_tf_utils import TFPreTrainedModel
from ..modeling_utils import PreTrainedModel
if is_vision_available():
from PIL import Image
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
logger = logging.get_logger(__name__)
@@ -37,24 +30,15 @@ class ImageClassificationPipeline(Pipeline):
<https://huggingface.co/models?filter=image-classification>`__.
"""
def __init__(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel"],
feature_extractor: PreTrainedFeatureExtractor,
framework: Optional[str] = None,
**kwargs
):
super().__init__(model, feature_extractor=feature_extractor, framework=framework, **kwargs)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.framework == "tf":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
self.feature_extractor = feature_extractor
@staticmethod
def load_image(image: Union[str, "Image.Image"]):
if isinstance(image, str):
@@ -77,7 +61,13 @@ class ImageClassificationPipeline(Pipeline):
image = image.convert("RGB")
return image
def __call__(self, images: Union[str, List[str], "Image", List["Image"]], top_k=5):
def _sanitize_parameters(self, top_k=None):
postprocess_params = {}
if top_k is not None:
postprocess_params["top_k"] = top_k
return {}, {}, postprocess_params
def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs):
"""
Assign labels to the image(s) passed as inputs.
@ -106,34 +96,23 @@ class ImageClassificationPipeline(Pipeline):
- **label** (:obj:`str`) -- The label identified by the model.
- **score** (:obj:`float`) -- The score attributed by the model for that label.
"""
is_batched = isinstance(images, list)
return super().__call__(images, **kwargs)
if not is_batched:
images = [images]
def preprocess(self, image):
image = self.load_image(image)
model_inputs = self.feature_extractor(images=image, return_tensors="pt")
return model_inputs
images = [self.load_image(image) for image in images]
def _forward(self, model_inputs):
model_outputs = self.model(**model_inputs)
return model_outputs
def postprocess(self, model_outputs, top_k=5):
if top_k > self.model.config.num_labels:
top_k = self.model.config.num_labels
probs = model_outputs.logits.softmax(-1)[0]
scores, ids = probs.topk(top_k)
with torch.no_grad():
inputs = self.feature_extractor(images=images, return_tensors="pt")
outputs = self.model(**inputs)
probs = outputs.logits.softmax(-1)
scores, ids = probs.topk(top_k)
scores = scores.tolist()
ids = ids.tolist()
if not is_batched:
scores, ids = scores[0], ids[0]
labels = [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
else:
labels = []
for scores, ids in zip(scores, ids):
labels.append(
[{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
)
return labels
scores = scores.tolist()
ids = ids.tolist()
return [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]


@ -1,17 +1,13 @@
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Union
import requests
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
if is_vision_available():
from PIL import Image
@ -40,24 +36,15 @@ class ObjectDetectionPipeline(Pipeline):
<https://huggingface.co/models?filter=object-detection>`__.
"""
def __init__(
self,
model: "PreTrainedModel",
feature_extractor: PreTrainedFeatureExtractor,
framework: Optional[str] = None,
**kwargs
):
super().__init__(model, feature_extractor=feature_extractor, framework=framework, **kwargs)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.framework == "tf":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_OBJECT_DETECTION_MAPPING)
self.feature_extractor = feature_extractor
@staticmethod
def load_image(image: Union[str, "Image.Image"]):
if isinstance(image, str):
@ -80,11 +67,13 @@ class ObjectDetectionPipeline(Pipeline):
image = image.convert("RGB")
return image
def __call__(
self,
images: Union[str, List[str], "Image", List["Image"]],
threshold: Optional[float] = 0.9,
) -> Union[Predictions, List[Prediction]]:
def _sanitize_parameters(self, **kwargs):
postprocess_kwargs = {}
if "threshold" in kwargs:
postprocess_kwargs["threshold"] = kwargs["threshold"]
return {}, {}, postprocess_kwargs
def __call__(self, *args, **kwargs) -> Union[Predictions, List[Prediction]]:
"""
Detect objects (bounding boxes & classes) in the image(s) passed as inputs.
@ -112,47 +101,42 @@ class ObjectDetectionPipeline(Pipeline):
- **score** (:obj:`float`) -- The score attributed by the model for that label.
- **box** (:obj:`List[Dict[str, int]]`) -- The bounding box of detected object in image's original size.
"""
is_batched = isinstance(images, list)
if not is_batched:
images = [images]
return super().__call__(*args, **kwargs)
images = [self.load_image(image) for image in images]
def preprocess(self, image):
image = self.load_image(image)
target_size = torch.IntTensor([[image.height, image.width]])
inputs = self.feature_extractor(images=[image], return_tensors="pt")
inputs["target_size"] = target_size
return inputs
with torch.no_grad():
inputs = self.feature_extractor(images=images, return_tensors="pt")
outputs = self.model(**inputs)
def _forward(self, model_inputs):
target_size = model_inputs.pop("target_size")
outputs = self.model(**model_inputs)
model_outputs = {"outputs": outputs, "target_size": target_size}
return model_outputs
if self.framework == "pt":
target_sizes = torch.IntTensor([[im.height, im.width] for im in images])
else:
raise ValueError("The ObjectDetectionPipeline is only available in PyTorch.")
def postprocess(self, model_outputs, threshold=0.9):
raw_annotations = self.feature_extractor.post_process(model_outputs["outputs"], model_outputs["target_size"])
raw_annotation = raw_annotations[0]
keep = raw_annotation["scores"] > threshold
scores = raw_annotation["scores"][keep]
labels = raw_annotation["labels"][keep]
boxes = raw_annotation["boxes"][keep]
raw_annotations = self.feature_extractor.post_process(outputs, target_sizes)
annotations = []
for annotation in raw_annotations:
keep = annotation["scores"] > threshold
scores = annotation["scores"][keep]
labels = annotation["labels"][keep]
boxes = annotation["boxes"][keep]
raw_annotation["scores"] = scores.tolist()
raw_annotation["labels"] = [self.model.config.id2label[label.item()] for label in labels]
raw_annotation["boxes"] = [self._get_bounding_box(box) for box in boxes]
annotation["scores"] = scores.tolist()
annotation["labels"] = [self.model.config.id2label[label.item()] for label in labels]
annotation["boxes"] = [self._get_bounding_box(box) for box in boxes]
# {"scores": [...], ...} --> [{"score":x, ...}, ...]
keys = ["score", "label", "box"]
annotation = [
dict(zip(keys, vals))
for vals in zip(raw_annotation["scores"], raw_annotation["labels"], raw_annotation["boxes"])
]
# {"scores": [...], ...} --> [{"score":x, ...}, ...]
keys = ["score", "label", "box"]
annotation = [
dict(zip(keys, vals))
for vals in zip(annotation["scores"], annotation["labels"], annotation["boxes"])
]
annotations.append(annotation)
if not is_batched:
return annotations[0]
return annotations
return annotation
def _get_bounding_box(self, box: "torch.Tensor") -> Dict[str, int]:
"""


@ -1,3 +1,4 @@
import warnings
from collections.abc import Iterable
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
@ -109,6 +110,7 @@ class QuestionAnsweringPipeline(Pipeline):
"""
default_input_names = "question,context"
handle_impossible_answer = False
def __init__(
self,
@ -158,6 +160,44 @@ class QuestionAnsweringPipeline(Pipeline):
else:
return SquadExample(None, question, context, None, None, None)
def _sanitize_parameters(
self,
padding=None,
topk=None,
top_k=None,
doc_stride=None,
max_answer_len=None,
max_seq_len=None,
max_question_len=None,
handle_impossible_answer=None,
**kwargs
):
# Set defaults values
preprocess_params = {}
if padding is not None:
preprocess_params["padding"] = padding
if doc_stride is not None:
preprocess_params["doc_stride"] = doc_stride
if max_question_len is not None:
preprocess_params["max_question_len"] = max_question_len
postprocess_params = {}
if topk is not None and top_k is None:
warnings.warn("topk parameter is deprecated, use top_k instead", UserWarning)
top_k = topk
if top_k is not None:
if top_k < 1:
raise ValueError(f"top_k parameter should be >= 1 (got {top_k})")
postprocess_params["top_k"] = top_k
if max_answer_len is not None:
if max_answer_len < 1:
raise ValueError(f"max_answer_len parameter should be >= 1 (got {max_answer_len}")
if max_answer_len is not None:
postprocess_params["max_answer_len"] = max_answer_len
if handle_impossible_answer is not None:
postprocess_params["handle_impossible_answer"] = handle_impossible_answer
return preprocess_params, {}, postprocess_params
def __call__(self, *args, **kwargs):
"""
Answer the question(s) given as inputs by using the context(s).
@ -201,211 +241,202 @@ class QuestionAnsweringPipeline(Pipeline):
- **end** (:obj:`int`) -- The character end index of the answer (in the tokenized version of the input).
- **answer** (:obj:`str`) -- The answer to the question.
"""
# Set defaults values
kwargs.setdefault("padding", "longest" if getattr(self.tokenizer, "pad_token", None) is not None else False)
kwargs.setdefault("topk", 1)
kwargs.setdefault("doc_stride", 128)
kwargs.setdefault("max_answer_len", 15)
kwargs.setdefault("max_seq_len", 384)
kwargs.setdefault("max_question_len", 64)
kwargs.setdefault("handle_impossible_answer", False)
if kwargs["topk"] < 1:
raise ValueError(f"topk parameter should be >= 1 (got {kwargs['topk']})")
if kwargs["max_answer_len"] < 1:
raise ValueError(f"max_answer_len parameter should be >= 1 (got {(kwargs['max_answer_len'])}")
# Convert inputs to features
examples = self._args_parser(*args, **kwargs)
if len(examples) == 1:
return super().__call__(examples[0], **kwargs)
return super().__call__(examples, **kwargs)
def preprocess(self, example, padding="do_not_pad", doc_stride=128, max_question_len=64, max_seq_len=384):
if not self.tokenizer.is_fast:
features_list = [
squad_convert_examples_to_features(
examples=[example],
tokenizer=self.tokenizer,
max_seq_length=kwargs["max_seq_len"],
doc_stride=kwargs["doc_stride"],
max_query_length=kwargs["max_question_len"],
padding_strategy=PaddingStrategy.MAX_LENGTH.value,
is_training=False,
tqdm_enabled=False,
)
for example in examples
]
features = squad_convert_examples_to_features(
examples=[example],
tokenizer=self.tokenizer,
max_seq_length=max_seq_len,
doc_stride=doc_stride,
max_query_length=max_question_len,
padding_strategy=PaddingStrategy.MAX_LENGTH,
is_training=False,
tqdm_enabled=False,
)
else:
features_list = []
for example in examples:
# Define the side we want to truncate / pad and the text/pair sorting
question_first = bool(self.tokenizer.padding_side == "right")
# Define the side we want to truncate / pad and the text/pair sorting
question_first = self.tokenizer.padding_side == "right"
encoded_inputs = self.tokenizer(
text=example.question_text if question_first else example.context_text,
text_pair=example.context_text if question_first else example.question_text,
padding=kwargs["padding"],
truncation="only_second" if question_first else "only_first",
max_length=kwargs["max_seq_len"],
stride=kwargs["doc_stride"],
return_tensors="np",
return_token_type_ids=True,
return_overflowing_tokens=True,
return_offsets_mapping=True,
return_special_tokens_mask=True,
)
encoded_inputs = self.tokenizer(
text=example.question_text if question_first else example.context_text,
text_pair=example.context_text if question_first else example.question_text,
padding=padding,
truncation="only_second" if question_first else "only_first",
max_length=max_seq_len,
stride=doc_stride,
return_tensors="np",
return_token_type_ids=True,
return_overflowing_tokens=True,
return_offsets_mapping=True,
return_special_tokens_mask=True,
)
# When the input is too long, it's converted into a batch of inputs with overflowing tokens
# and a stride of overlap between the inputs. If a batch of inputs is given, a special output
# "overflow_to_sample_mapping" indicates which member of the encoded batch belongs to which original batch sample.
# Here we tokenize examples one-by-one so we don't need to use "overflow_to_sample_mapping".
# "num_span" is the number of output samples generated from the overflowing tokens.
num_spans = len(encoded_inputs["input_ids"])
# When the input is too long, it's converted into a batch of inputs with overflowing tokens
# and a stride of overlap between the inputs. If a batch of inputs is given, a special output
# "overflow_to_sample_mapping" indicates which member of the encoded batch belongs to which original batch sample.
# Here we tokenize examples one-by-one so we don't need to use "overflow_to_sample_mapping".
# "num_span" is the number of output samples generated from the overflowing tokens.
num_spans = len(encoded_inputs["input_ids"])
# p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens which can be in an answer)
# We put 0 on the tokens from the context and 1 everywhere else (question and special tokens)
p_mask = np.asarray(
[
[tok != 1 if question_first else 0 for tok in encoded_inputs.sequence_ids(span_id)]
for span_id in range(num_spans)
]
)
# p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens which can be in an answer)
# We put 0 on the tokens from the context and 1 everywhere else (question and special tokens)
p_mask = np.asarray(
[
[tok != 1 if question_first else 0 for tok in encoded_inputs.sequence_ids(span_id)]
for span_id in range(num_spans)
]
)
# keep the cls_token unmasked (some models use it to indicate unanswerable questions)
if self.tokenizer.cls_token_id is not None:
cls_index = np.nonzero(encoded_inputs["input_ids"] == self.tokenizer.cls_token_id)
p_mask[cls_index] = 0
# keep the cls_token unmasked (some models use it to indicate unanswerable questions)
if self.tokenizer.cls_token_id is not None:
cls_index = np.nonzero(encoded_inputs["input_ids"] == self.tokenizer.cls_token_id)
p_mask[cls_index] = 0
features = []
for span_idx in range(num_spans):
features.append(
SquadFeatures(
input_ids=encoded_inputs["input_ids"][span_idx],
attention_mask=encoded_inputs["attention_mask"][span_idx],
token_type_ids=encoded_inputs["token_type_ids"][span_idx],
p_mask=p_mask[span_idx].tolist(),
encoding=encoded_inputs[span_idx],
# We don't use the rest of the values - and actually
# for Fast tokenizer we could totally avoid using SquadFeatures and SquadExample
cls_index=None,
token_to_orig_map={},
example_index=0,
unique_id=0,
paragraph_len=0,
token_is_max_context=0,
tokens=[],
start_position=0,
end_position=0,
is_impossible=False,
qas_id=None,
)
features = []
for span_idx in range(num_spans):
features.append(
SquadFeatures(
input_ids=encoded_inputs["input_ids"][span_idx],
attention_mask=encoded_inputs["attention_mask"][span_idx],
token_type_ids=encoded_inputs["token_type_ids"][span_idx],
p_mask=p_mask[span_idx].tolist(),
encoding=encoded_inputs[span_idx],
# We don't use the rest of the values - and actually
# for Fast tokenizer we could totally avoid using SquadFeatures and SquadExample
cls_index=None,
token_to_orig_map={},
example_index=0,
unique_id=0,
paragraph_len=0,
token_is_max_context=0,
tokens=[],
start_position=0,
end_position=0,
is_impossible=False,
qas_id=None,
)
features_list.append(features)
all_answers = []
for features, example in zip(features_list, examples):
model_input_names = self.tokenizer.model_input_names
fw_args = {k: [feature.__dict__[k] for feature in features] for k in model_input_names}
# Manage tensor allocation on correct device
with self.device_placement():
if self.framework == "tf":
fw_args = {k: tf.constant(v) for (k, v) in fw_args.items()}
start, end = self.model(fw_args)[:2]
start, end = start.numpy(), end.numpy()
else:
with torch.no_grad():
# Retrieve the score for the context tokens only (removing question tokens)
fw_args = {k: torch.tensor(v, device=self.device) for (k, v) in fw_args.items()}
# On Windows, the default int type in numpy is np.int32 so we get some non-long tensors.
fw_args = {k: v.long() if v.dtype == torch.int32 else v for (k, v) in fw_args.items()}
start, end = self.model(**fw_args)[:2]
start, end = start.cpu().numpy(), end.cpu().numpy()
min_null_score = 1000000 # large and positive
answers = []
for (feature, start_, end_) in zip(features, start, end):
# Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
undesired_tokens = np.abs(np.array(feature.p_mask) - 1) & feature.attention_mask
# Generate mask
undesired_tokens_mask = undesired_tokens == 0.0
# Make sure non-context indexes in the tensor cannot contribute to the softmax
start_ = np.where(undesired_tokens_mask, -10000.0, start_)
end_ = np.where(undesired_tokens_mask, -10000.0, end_)
# Normalize logits and spans to retrieve the answer
start_ = np.exp(start_ - np.log(np.sum(np.exp(start_), axis=-1, keepdims=True)))
end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True)))
if kwargs["handle_impossible_answer"]:
min_null_score = min(min_null_score, (start_[0] * end_[0]).item())
# Mask CLS
start_[0] = end_[0] = 0.0
starts, ends, scores = self.decode(
start_, end_, kwargs["topk"], kwargs["max_answer_len"], undesired_tokens
)
if not self.tokenizer.is_fast:
char_to_word = np.array(example.char_to_word_offset)
return {"features": features, "example": example}
# Convert the answer (tokens) back to the original text
# Score: score from the model
# Start: Index of the first character of the answer in the context string
# End: Index of the character following the last character of the answer in the context string
# Answer: Plain text of the answer
for s, e, score in zip(starts, ends, scores):
answers.append(
{
"score": score.item(),
"start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
"end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
"answer": " ".join(
example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1]
),
}
)
else:
# Convert the answer (tokens) back to the original text
# Score: score from the model
# Start: Index of the first character of the answer in the context string
# End: Index of the character following the last character of the answer in the context string
# Answer: Plain text of the answer
question_first = bool(self.tokenizer.padding_side == "right")
enc = feature.encoding
def _forward(self, model_inputs):
features = model_inputs["features"]
example = model_inputs["example"]
model_input_names = self.tokenizer.model_input_names
fw_args = {k: [feature.__dict__[k] for feature in features] for k in model_input_names}
# Sometimes the max probability token is in the middle of a word so:
# - we start by finding the right word containing the token with `token_to_word`
# - then we convert this word in a character span with `word_to_chars`
sequence_index = 1 if question_first else 0
for s, e, score in zip(starts, ends, scores):
try:
start_word = enc.token_to_word(s)
end_word = enc.token_to_word(e)
start_index = enc.word_to_chars(start_word, sequence_index=sequence_index)[0]
end_index = enc.word_to_chars(end_word, sequence_index=sequence_index)[1]
except Exception:
# Some tokenizers don't really handle words. Keep to offsets then.
start_index = enc.offsets[s][0]
end_index = enc.offsets[e][1]
if self.framework == "tf":
fw_args = {k: tf.constant(v) for (k, v) in fw_args.items()}
start, end = self.model(fw_args)[:2]
start, end = start.numpy(), end.numpy()
elif self.framework == "pt":
# Retrieve the score for the context tokens only (removing question tokens)
fw_args = {k: torch.tensor(v, device=self.device) for (k, v) in fw_args.items()}
# On Windows, the default int type in numpy is np.int32 so we get some non-long tensors.
fw_args = {k: v.long() if v.dtype == torch.int32 else v for (k, v) in fw_args.items()}
start, end = self.model(**fw_args)[:2]
start, end = start.cpu().numpy(), end.cpu().numpy()
return {"start": start, "end": end, "features": features, "example": example}
answers.append(
{
"score": score.item(),
"start": start_index,
"end": end_index,
"answer": example.context_text[start_index:end_index],
}
)
def postprocess(
self,
model_outputs,
top_k=1,
handle_impossible_answer=False,
max_answer_len=15,
):
min_null_score = 1000000 # large and positive
answers = []
start_ = model_outputs["start"][0]
end_ = model_outputs["end"][0]
feature = model_outputs["features"][0]
example = model_outputs["example"]
# Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
undesired_tokens = np.abs(np.array(feature.p_mask) - 1) & feature.attention_mask
if kwargs["handle_impossible_answer"]:
answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""})
# Generate mask
undesired_tokens_mask = undesired_tokens == 0.0
answers = sorted(answers, key=lambda x: x["score"], reverse=True)[: kwargs["topk"]]
all_answers += answers
# Make sure non-context indexes in the tensor cannot contribute to the softmax
start_ = np.where(undesired_tokens_mask, -10000.0, start_)
end_ = np.where(undesired_tokens_mask, -10000.0, end_)
if len(all_answers) == 1:
return all_answers[0]
return all_answers
# Normalize logits and spans to retrieve the answer
start_ = np.exp(start_ - np.log(np.sum(np.exp(start_), axis=-1, keepdims=True)))
end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True)))
if handle_impossible_answer:
min_null_score = min(min_null_score, (start_[0] * end_[0]).item())
# Mask CLS
start_[0] = end_[0] = 0.0
starts, ends, scores = self.decode(start_, end_, top_k, max_answer_len, undesired_tokens)
if not self.tokenizer.is_fast:
char_to_word = np.array(example.char_to_word_offset)
# Convert the answer (tokens) back to the original text
# Score: score from the model
# Start: Index of the first character of the answer in the context string
# End: Index of the character following the last character of the answer in the context string
# Answer: Plain text of the answer
for s, e, score in zip(starts, ends, scores):
answers.append(
{
"score": score.item(),
"start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
"end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
"answer": " ".join(
example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1]
),
}
)
else:
# Convert the answer (tokens) back to the original text
# Score: score from the model
# Start: Index of the first character of the answer in the context string
# End: Index of the character following the last character of the answer in the context string
# Answer: Plain text of the answer
question_first = bool(self.tokenizer.padding_side == "right")
enc = feature.encoding
# Sometimes the max probability token is in the middle of a word so:
# - we start by finding the right word containing the token with `token_to_word`
# - then we convert this word in a character span with `word_to_chars`
sequence_index = 1 if question_first else 0
for s, e, score in zip(starts, ends, scores):
try:
start_word = enc.token_to_word(s)
end_word = enc.token_to_word(e)
start_index = enc.word_to_chars(start_word, sequence_index=sequence_index)[0]
end_index = enc.word_to_chars(end_word, sequence_index=sequence_index)[1]
except Exception:
# Some tokenizers don't really handle words. Keep to offsets then.
start_index = enc.offsets[s][0]
end_index = enc.offsets[e][1]
answers.append(
{
"score": score.item(),
"start": start_index,
"end": end_index,
"answer": example.context_text[start_index:end_index],
}
)
if handle_impossible_answer:
answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""})
answers = sorted(answers, key=lambda x: x["score"], reverse=True)[:top_k]
if len(answers) == 1:
return answers[0]
return answers
def decode(
self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int, undesired_tokens: np.ndarray
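
A hedged sketch of the question-answering pipeline after the rework; the checkpoint is a placeholder. `top_k`, `doc_stride`, `max_answer_len` and `handle_impossible_answer` are real named arguments split by `_sanitize_parameters`, and the old `topk` spelling only survives behind a deprecation warning.

from transformers import pipeline

# Hypothetical SQuAD-finetuned checkpoint.
qa = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
answer = qa(
    question="Where does Wolfgang live?",
    context="My name is Wolfgang and I live in Berlin.",
    top_k=1,
    max_answer_len=15,
)
print(answer["answer"], answer["score"])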


@ -17,7 +17,7 @@ class TableQuestionAnsweringArgumentHandler(ArgumentHandler):
Handles arguments for the TableQuestionAnsweringPipeline
"""
def __call__(self, table=None, query=None, sequential=False, padding=True, truncation=True):
def __call__(self, table=None, query=None, **kwargs):
# Returns tqa_pipeline_inputs of shape:
# [
# {"table": pd.DataFrame, "query": List[str]},
@ -60,7 +60,7 @@ class TableQuestionAnsweringArgumentHandler(ArgumentHandler):
tqa_pipeline_input["table"] = pd.DataFrame(tqa_pipeline_input["table"])
return tqa_pipeline_inputs, sequential, padding, truncation
return tqa_pipeline_inputs
@add_end_docstrings(PIPELINE_INIT_ARGS)
@ -235,52 +235,76 @@ class TableQuestionAnsweringPipeline(Pipeline):
- **cells** (:obj:`List[str]`) -- List of strings made up of the answer cell values.
- **aggregator** (:obj:`str`) -- If the model has an aggregator, this returns the aggregator.
"""
pipeline_inputs, sequential, padding, truncation = self._args_parser(*args, **kwargs)
batched_answers = []
for pipeline_input in pipeline_inputs:
table, query = pipeline_input["table"], pipeline_input["query"]
if table.empty:
raise ValueError("table is empty")
if not query:
raise ValueError("query is empty")
inputs = self.tokenizer(
table, query, return_tensors=self.framework, truncation="drop_rows_to_fit", padding=padding
)
pipeline_inputs = self._args_parser(*args, **kwargs)
outputs = self.sequential_inference(**inputs) if sequential else self.batch_inference(**inputs)
results = super().__call__(pipeline_inputs, **kwargs)
if len(results) == 1:
return results[0]
return results
if self.aggregate:
logits, logits_agg = outputs[:2]
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits.detach(), logits_agg)
answer_coordinates_batch, agg_predictions = predictions
aggregators = {i: self.model.config.aggregation_labels[pred] for i, pred in enumerate(agg_predictions)}
def _sanitize_parameters(self, sequential=None, padding=None, truncation=None, **kwargs):
preprocess_params = {}
if padding is not None:
preprocess_params["padding"] = padding
if truncation is not None:
preprocess_params["truncation"] = truncation
no_agg_label_index = self.model.config.no_aggregation_label_index
aggregators_prefix = {
i: aggregators[i] + " > " for i, pred in enumerate(agg_predictions) if pred != no_agg_label_index
}
else:
logits = outputs[0]
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits.detach())
answer_coordinates_batch = predictions[0]
aggregators = {}
aggregators_prefix = {}
forward_params = {}
if sequential is not None:
forward_params["sequential"] = sequential
return preprocess_params, forward_params, {}
answers = []
for index, coordinates in enumerate(answer_coordinates_batch):
cells = [table.iat[coordinate] for coordinate in coordinates]
aggregator = aggregators.get(index, "")
aggregator_prefix = aggregators_prefix.get(index, "")
answer = {
"answer": aggregator_prefix + ", ".join(cells),
"coordinates": coordinates,
"cells": [table.iat[coordinate] for coordinate in coordinates],
}
if aggregator:
answer["aggregator"] = aggregator
def preprocess(self, pipeline_input, sequential=None, padding=True, truncation="drop_rows_to_fit"):
table, query = pipeline_input["table"], pipeline_input["query"]
if table.empty:
raise ValueError("table is empty")
if query is None or query == "":
raise ValueError("query is empty")
inputs = self.tokenizer(table, query, return_tensors=self.framework, truncation=truncation, padding=padding)
inputs["table"] = table
return inputs
answers.append(answer)
if len(answer) == 0:
raise PipelineException("Empty answer")
batched_answers.append(answers if len(answers) > 1 else answers[0])
return batched_answers if len(batched_answers) > 1 else batched_answers[0]
def _forward(self, model_inputs, sequential=False):
table = model_inputs.pop("table")
outputs = self.sequential_inference(**model_inputs) if sequential else self.batch_inference(**model_inputs)
model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs}
return model_outputs
def postprocess(self, model_outputs):
inputs = model_outputs["model_inputs"]
table = model_outputs["table"]
outputs = model_outputs["outputs"]
if self.aggregate:
logits, logits_agg = outputs[:2]
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits.detach(), logits_agg)
answer_coordinates_batch, agg_predictions = predictions
aggregators = {i: self.model.config.aggregation_labels[pred] for i, pred in enumerate(agg_predictions)}
no_agg_label_index = self.model.config.no_aggregation_label_index
aggregators_prefix = {
i: aggregators[i] + " > " for i, pred in enumerate(agg_predictions) if pred != no_agg_label_index
}
else:
logits = outputs[0]
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits.detach())
answer_coordinates_batch = predictions[0]
aggregators = {}
aggregators_prefix = {}
answers = []
for index, coordinates in enumerate(answer_coordinates_batch):
cells = [table.iat[coordinate] for coordinate in coordinates]
aggregator = aggregators.get(index, "")
aggregator_prefix = aggregators_prefix.get(index, "")
answer = {
"answer": aggregator_prefix + ", ".join(cells),
"coordinates": coordinates,
"cells": [table.iat[coordinate] for coordinate in coordinates],
}
if aggregator:
answer["aggregator"] = aggregator
answers.append(answer)
if len(answer) == 0:
raise PipelineException("Empty answer")
return answers if len(answers) > 1 else answers[0]
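
A sketch of the table-question-answering pipeline with the new parameter routing (checkpoint and toy table are assumptions); `padding`/`truncation` go to `preprocess` and `sequential` to `_forward`, as in the `_sanitize_parameters` above.

import pandas as pd

from transformers import pipeline

# Hypothetical TAPAS checkpoint and a toy table.
table_qa = pipeline("table-question-answering", model="google/tapas-base-finetuned-wtq")
table = pd.DataFrame({"city": ["Paris", "Berlin"], "inhabitants": ["2.1 million", "3.6 million"]})
result = table_qa(table=table, query="Which city has more inhabitants?", sequential=False)
print(result["answer"])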


@ -1,4 +1,4 @@
from typing import Optional
import enum
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
from ..tokenization_utils import TruncationStrategy
@ -17,6 +17,11 @@ if is_torch_available():
logger = logging.get_logger(__name__)
class ReturnType(enum.Enum):
TENSORS = 0
TEXT = 1
@add_end_docstrings(PIPELINE_INIT_ARGS)
class Text2TextGenerationPipeline(Pipeline):
"""
@ -46,6 +51,32 @@ class Text2TextGenerationPipeline(Pipeline):
else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
)
def _sanitize_parameters(
self,
return_tensors=None,
return_text=None,
return_type=None,
clean_up_tokenization_spaces=None,
truncation=None,
**generate_kwargs
):
preprocess_params = {}
if truncation is not None:
preprocess_params["truncation"] = truncation
forward_params = generate_kwargs
postprocess_params = {}
if return_tensors is not None and return_type is None:
return_type = ReturnType.TENSORS if return_tensors else ReturnType.TEXT
if return_type is not None:
postprocess_params["return_type"] = return_type
if clean_up_tokenization_spaces is not None:
postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces
return preprocess_params, forward_params, postprocess_params
def check_inputs(self, input_length: int, min_length: int, max_length: int):
"""
Checks whether there might be something wrong with given input with regard to the model.
@ -55,9 +86,8 @@ class Text2TextGenerationPipeline(Pipeline):
def _parse_and_tokenize(self, *args, truncation):
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
if isinstance(args[0], list):
assert (
self.tokenizer.pad_token_id is not None
), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
if self.tokenizer.pad_token_id is None:
raise ValueError("Please make sure that the tokenizer has a pad_token_id when using a batch input")
args = ([prefix + arg for arg in args[0]],)
padding = True
@ -68,21 +98,13 @@ class Text2TextGenerationPipeline(Pipeline):
raise ValueError(
f" `args[0]`: {args[0]} have the wrong format. The should be either of type `str` or type `list`"
)
inputs = super()._parse_and_tokenize(*args, padding=padding, truncation=truncation)
inputs = self.tokenizer(*args, padding=padding, truncation=truncation, return_tensors=self.framework)
# This is produced by tokenizers but is an invalid generate kwargs
if "token_type_ids" in inputs:
del inputs["token_type_ids"]
return inputs
def __call__(
self,
*args,
return_tensors=False,
return_text=True,
clean_up_tokenization_spaces=False,
truncation=TruncationStrategy.DO_NOT_TRUNCATE,
**generate_kwargs
):
def __call__(self, *args, **kwargs):
r"""
Generate the output text(s) using text(s) given as inputs.
@ -111,43 +133,40 @@ class Text2TextGenerationPipeline(Pipeline):
-- The token ids of the generated text.
"""
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
result = super().__call__(*args, **kwargs)
if isinstance(result, dict):
return [result]
return result
with self.device_placement():
inputs = self._parse_and_tokenize(*args, truncation=truncation)
return self._generate(inputs, return_tensors, return_text, clean_up_tokenization_spaces, generate_kwargs)
def preprocess(self, inputs, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs):
inputs = self._parse_and_tokenize(inputs, truncation=truncation, **kwargs)
return inputs
def _generate(
self, inputs, return_tensors: bool, return_text: bool, clean_up_tokenization_spaces: bool, generate_kwargs
):
def _forward(self, model_inputs, **generate_kwargs):
if self.framework == "pt":
inputs = self.ensure_tensor_on_device(**inputs)
input_length = inputs["input_ids"].shape[-1]
input_length = model_inputs["input_ids"].shape[-1]
elif self.framework == "tf":
input_length = tf.shape(inputs["input_ids"])[-1].numpy()
input_length = tf.shape(model_inputs["input_ids"])[-1].numpy()
min_length = generate_kwargs.get("min_length", self.model.config.min_length)
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
self.check_inputs(input_length, min_length, max_length)
generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])
output_ids = self.model.generate(**model_inputs, **generate_kwargs)
return {"output_ids": output_ids}
generate_kwargs.update(inputs)
generations = self.model.generate(
**generate_kwargs,
)
results = []
for generation in generations:
record = {}
if return_tensors:
record[f"{self.return_name}_token_ids"] = generation
if return_text:
record[f"{self.return_name}_text"] = self.tokenizer.decode(
generation,
def postprocess(self, model_outputs, return_type=ReturnType.TEXT, clean_up_tokenization_spaces=False):
record = {}
if return_type == ReturnType.TENSORS:
record = {f"{self.return_name}_token_ids": model_outputs}
elif return_type == ReturnType.TEXT:
record = {
f"{self.return_name}_text": self.tokenizer.decode(
model_outputs["output_ids"][0],
skip_special_tokens=True,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
results.append(record)
return results
}
return record
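
A minimal sketch of the text2text-generation pipeline after the rework (checkpoint and prompt are placeholders); `return_text`/`return_tensors` are folded into the new `ReturnType` enum, and all remaining kwargs become `generate_kwargs` handed to `_forward`.

from transformers import pipeline

# Hypothetical seq2seq checkpoint.
text2text = pipeline("text2text-generation", model="t5-small")
print(text2text("translate English to German: The house is wonderful.", max_length=40))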
@add_end_docstrings(PIPELINE_INIT_ARGS)
@ -239,23 +258,6 @@ class TranslationPipeline(Text2TextGenerationPipeline):
# Used in the return key of the pipeline.
return_name = "translation"
src_lang: Optional[str] = None
tgt_lang: Optional[str] = None
def __init__(self, *args, src_lang=None, tgt_lang=None, **kwargs):
super().__init__(*args, **kwargs)
if src_lang is not None:
self.src_lang = src_lang
if tgt_lang is not None:
self.tgt_lang = tgt_lang
if src_lang is None and tgt_lang is None:
# Backward compatibility, direct arguments use is preferred.
task = kwargs.get("task", "")
items = task.split("_")
if task and len(items) == 4:
# translation, XX, to YY
self.src_lang = items[1]
self.tgt_lang = items[3]
def check_inputs(self, input_length: int, min_length: int, max_length: int):
if input_length > 0.9 * max_length:
@ -265,25 +267,31 @@ class TranslationPipeline(Text2TextGenerationPipeline):
)
return True
def _parse_and_tokenize(self, *args, src_lang, tgt_lang, truncation):
def preprocess(self, *args, truncation=TruncationStrategy.DO_NOT_TRUNCATE, src_lang=None, tgt_lang=None):
if getattr(self.tokenizer, "_build_translation_inputs", None):
return self.tokenizer._build_translation_inputs(
*args, return_tensors=self.framework, src_lang=src_lang, tgt_lang=tgt_lang, truncation=truncation
*args, return_tensors=self.framework, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang
)
else:
return super()._parse_and_tokenize(*args, truncation=truncation)
def __call__(
self,
*args,
return_tensors=False,
return_text=True,
clean_up_tokenization_spaces=False,
truncation=TruncationStrategy.DO_NOT_TRUNCATE,
src_lang=None,
tgt_lang=None,
**generate_kwargs
):
def _sanitize_parameters(self, src_lang=None, tgt_lang=None, **kwargs):
preprocess_params, forward_params, postprocess_params = super()._sanitize_parameters(**kwargs)
if src_lang is not None:
preprocess_params["src_lang"] = src_lang
if tgt_lang is not None:
preprocess_params["tgt_lang"] = tgt_lang
if src_lang is None and tgt_lang is None:
# Backward compatibility, direct arguments use is preferred.
task = kwargs.get("task", self.task)
items = task.split("_")
if task and len(items) == 4:
# translation, XX, to YY
preprocess_params["src_lang"] = items[1]
preprocess_params["tgt_lang"] = items[3]
return preprocess_params, forward_params, postprocess_params
def __call__(self, *args, **kwargs):
r"""
Translate the text(s) given as inputs.
@ -313,10 +321,4 @@ class TranslationPipeline(Text2TextGenerationPipeline):
- **translation_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
-- The token ids of the translation.
"""
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
src_lang = src_lang if src_lang is not None else self.src_lang
tgt_lang = tgt_lang if tgt_lang is not None else self.tgt_lang
with self.device_placement():
inputs = self._parse_and_tokenize(*args, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang)
return self._generate(inputs, return_tensors, return_text, clean_up_tokenization_spaces, generate_kwargs)
return super().__call__(*args, **kwargs)
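
A sketch of the translation pipeline; the checkpoint is an assumption. `src_lang`/`tgt_lang` are resolved in `_sanitize_parameters`, falling back to parsing a `translation_xx_to_yy` task name for backward compatibility.

from transformers import pipeline

# Hypothetical checkpoint; t5-small understands the en->fr task prefix.
translator = pipeline("translation_en_to_fr", model="t5-small")
print(translator("The parameters are now real named arguments.", max_length=60))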


@ -1,9 +1,9 @@
from typing import Optional
from typing import Dict
import numpy as np
from ..file_utils import ExplicitEnum, add_end_docstrings, is_tf_available, is_torch_available
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import PIPELINE_INIT_ARGS, GenericTensor, Pipeline
if is_tf_available():
@ -61,9 +61,10 @@ class TextClassificationPipeline(Pipeline):
<https://huggingface.co/models?filter=text-classification>`__.
"""
task = "text-classification"
return_all_scores = False
function_to_apply = ClassificationFunction.NONE
def __init__(self, return_all_scores: bool = None, function_to_apply: str = None, **kwargs):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.check_model_type(
@ -72,22 +73,24 @@ class TextClassificationPipeline(Pipeline):
else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
)
def _sanitize_parameters(self, return_all_scores=None, function_to_apply=None, **tokenizer_kwargs):
preprocess_params = tokenizer_kwargs
postprocess_params = {}
if hasattr(self.model.config, "return_all_scores") and return_all_scores is None:
return_all_scores = self.model.config.return_all_scores
if hasattr(self.model.config, "function_to_apply") and function_to_apply is None:
function_to_apply = self.model.config.function_to_apply
if return_all_scores is not None:
postprocess_params["return_all_scores"] = return_all_scores
self.return_all_scores = return_all_scores if return_all_scores is not None else False
self.function_to_apply = function_to_apply if function_to_apply is not None else None
if isinstance(function_to_apply, str):
function_to_apply = ClassificationFunction[function_to_apply.upper()]
def __call__(
self,
*args,
return_all_scores: Optional[bool] = None,
function_to_apply: Optional[ClassificationFunction] = None,
**kwargs
):
if function_to_apply is not None:
postprocess_params["function_to_apply"] = function_to_apply
return preprocess_params, {}, postprocess_params
def __call__(self, *args, **kwargs):
"""
Classify the text(s) given as inputs.
@ -120,19 +123,32 @@ class TextClassificationPipeline(Pipeline):
If ``self.return_all_scores=True``, one such dictionary is returned per label.
"""
outputs = super().__call__(*args, **kwargs)
return super().__call__(*args, **kwargs)
return_all_scores = return_all_scores if return_all_scores is not None else self.return_all_scores
function_to_apply = function_to_apply if function_to_apply is not None else self.function_to_apply
def preprocess(self, inputs, **tokenizer_kwargs) -> Dict[str, GenericTensor]:
return_tensors = self.framework
return self.tokenizer(inputs, return_tensors=return_tensors, **tokenizer_kwargs)
def _forward(self, model_inputs):
return self.model(**model_inputs)
def postprocess(self, model_outputs, function_to_apply=None, return_all_scores=False):
# Default value before `set_parameters`
if function_to_apply is None:
if self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels == 1:
function_to_apply = ClassificationFunction.SIGMOID
elif self.model.config.problem_type == "single_label_classification" or self.model.config.num_labels > 1:
function_to_apply = ClassificationFunction.SOFTMAX
elif hasattr(self.model.config, "function_to_apply") and function_to_apply is None:
function_to_apply = self.model.config.function_to_apply
else:
function_to_apply = ClassificationFunction.NONE
if isinstance(function_to_apply, str):
function_to_apply = ClassificationFunction[function_to_apply.upper()]
outputs = model_outputs["logits"][0]
if self.framework == "pt":
outputs = outputs.cpu().numpy()
else:
outputs = outputs.numpy()
if function_to_apply == ClassificationFunction.SIGMOID:
scores = sigmoid(outputs)
@ -144,11 +160,13 @@ class TextClassificationPipeline(Pipeline):
raise ValueError(f"Unrecognized `function_to_apply` argument: {function_to_apply}")
if return_all_scores:
return [
[{"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(item)]
for item in scores
]
return [{"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(scores)]
else:
return [
{"label": self.model.config.id2label[item.argmax()], "score": item.max().item()} for item in scores
]
return {"label": self.model.config.id2label[scores.argmax().item()], "score": scores.max().item()}
def run_multi(self, inputs, preprocess_params, forward_params, postprocess_params):
return [self.run_single(item, preprocess_params, forward_params, postprocess_params)[0] for item in inputs]
def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
"This pipeline is odd, and return a list when single item is run"
return [super().run_single(inputs, preprocess_params, forward_params, postprocess_params)]
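
A hedged sketch of the text-classification pipeline after moving `return_all_scores` and `function_to_apply` from instance attributes to `postprocess` kwargs (the checkpoint is a placeholder).

from transformers import pipeline

# Hypothetical sentiment checkpoint.
classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
print(classifier("The rework makes parameter handling explicit.", return_all_scores=True))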


@ -1,9 +1,17 @@
import enum
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING
from ..file_utils import add_end_docstrings
from .base import PIPELINE_INIT_ARGS, Pipeline
class ReturnType(enum.Enum):
TENSORS = 0
NEW_TEXT = 1
FULL_TEXT = 2
@add_end_docstrings(PIPELINE_INIT_ARGS)
class TextGenerationPipeline(Pipeline):
"""
@ -32,29 +40,72 @@ class TextGenerationPipeline(Pipeline):
begging for his blessing. <eod> </s> <eos>
"""
ALLOWED_MODELS = [
"XLNetLMHeadModel",
"TransfoXLLMHeadModel",
"ReformerModelWithLMHead",
"GPT2LMHeadModel",
"GPTJForCausalLM",
"GPTNeoForCausalLM",
"OpenAIGPTLMHeadModel",
"CTRLLMHeadModel",
"TFXLNetLMHeadModel",
"TFTransfoXLLMHeadModel",
"TFGPT2LMHeadModel",
"TFOpenAIGPTLMHeadModel",
"TFCTRLLMHeadModel",
]
def __init__(self, *args, return_full_text=True, **kwargs):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.check_model_type(
TF_MODEL_FOR_CAUSAL_LM_MAPPING if self.framework == "tf" else MODEL_FOR_CAUSAL_LM_MAPPING
)
if "prefix" not in self._preprocess_params:
# This is very specific. The logic is quite complex and needs to be done
# as a "default".
# It also defines both some preprocess_kwargs and generate_kwargs
# which is why we cannot put them in their respective methods.
prefix = None
if self.model.config.prefix is not None:
prefix = self.model.config.prefix
if prefix is None and self.model.__class__.__name__ in [
"XLNetLMHeadModel",
"TransfoXLLMHeadModel",
"TFXLNetLMHeadModel",
"TFTransfoXLLMHeadModel",
]:
# For XLNet and TransformerXL we add an article to the prompt to give more state to the model.
prefix = self.XL_PREFIX
if prefix is not None:
# Recalculate some generate_kwargs linked to prefix.
preprocess_params, forward_params, _ = self._sanitize_parameters(prefix=prefix, **self._forward_params)
self._preprocess_params = {**self._preprocess_params, **preprocess_params}
self._forward_params = {**self._forward_params, **forward_params}
self.return_full_text = return_full_text
def _sanitize_parameters(
self,
return_full_text=None,
return_tensors=None,
return_text=None,
return_type=None,
clean_up_tokenization_spaces=None,
prefix=None,
**generate_kwargs
):
preprocess_params = {}
if prefix is not None:
preprocess_params["prefix"] = prefix
if prefix:
prefix_inputs = self.tokenizer(
prefix, padding=False, add_special_tokens=False, return_tensors=self.framework
)
prefix_length = prefix_inputs["input_ids"].shape[-1]
if "max_length" in generate_kwargs:
generate_kwargs["max_length"] += prefix_length
else:
generate_kwargs["max_length"] = self.model.config.max_length + prefix_length
if "min_length" in generate_kwargs:
generate_kwargs["min_length"] += prefix_length
forward_params = generate_kwargs
postprocess_params = {}
if return_full_text is not None and return_type is None:
return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT
if return_tensors is not None and return_type is None:
return_type = ReturnType.TENSORS
if return_type is not None:
postprocess_params["return_type"] = return_type
if clean_up_tokenization_spaces is not None:
postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces
return preprocess_params, forward_params, postprocess_params
# overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
def _parse_and_tokenize(self, *args, **kwargs):
@ -67,16 +118,7 @@ class TextGenerationPipeline(Pipeline):
return super()._parse_and_tokenize(*args, **kwargs)
def __call__(
self,
text_inputs,
return_tensors=False,
return_text=True,
return_full_text=None,
clean_up_tokenization_spaces=False,
prefix=None,
**generate_kwargs
):
def __call__(self, text_inputs, **kwargs):
"""
Complete the prompt(s) given as inputs.
@ -105,95 +147,58 @@ class TextGenerationPipeline(Pipeline):
- **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
-- The token ids of the generated text.
"""
prefix = prefix if prefix is not None else self.model.config.prefix
return_full_text = return_full_text if return_full_text is not None else self.return_full_text
return super().__call__(text_inputs, **kwargs)
if isinstance(text_inputs, str):
text_inputs = [text_inputs]
results = []
for prompt_text in text_inputs:
# Manage correct placement of the tensors
with self.device_placement():
if prefix is None and self.model.__class__.__name__ in [
"XLNetLMHeadModel",
"TransfoXLLMHeadModel",
"TFXLNetLMHeadModel",
"TFTransfoXLLMHeadModel",
]:
# For XLNet and TransformerXL we add an article to the prompt to give more state to the model.
prefix = self.XL_PREFIX
def preprocess(self, prompt_text, prefix=""):
inputs = self.tokenizer(
prefix + prompt_text, padding=False, add_special_tokens=False, return_tensors=self.framework
)
inputs["prompt_text"] = prompt_text
return inputs
if prefix:
prefix_inputs = self._parse_and_tokenize(prefix, padding=False, add_special_tokens=False)
# This impacts max_length and min_length argument that need adjusting.
prefix_length = prefix_inputs["input_ids"].shape[-1]
if generate_kwargs.get("max_length", None) is not None:
generate_kwargs["max_length"] += prefix_length
else:
generate_kwargs["max_length"] = self.model.config.max_length + prefix_length
def _forward(self, model_inputs, **generate_kwargs):
input_ids = model_inputs["input_ids"]
prompt_text = model_inputs.pop("prompt_text")
generated_sequence = self.model.generate(input_ids=input_ids, **generate_kwargs) # BS x SL
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
if generate_kwargs.get("min_length", None) is not None:
generate_kwargs["min_length"] += prefix_length
def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
generated_sequence = model_outputs["generated_sequence"]
input_ids = model_outputs["input_ids"]
prompt_text = model_outputs["prompt_text"]
if self.framework == "pt" and generated_sequence is not None:
generated_sequence = generated_sequence.cpu()
generated_sequence = generated_sequence.numpy().tolist()
if return_type == ReturnType.TENSORS:
record = {"generated_token_ids": generated_sequence}
elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
# Decode text
record = []
for sequence in generated_sequence:
text = self.tokenizer.decode(
sequence,
skip_special_tokens=True,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
prefix = prefix or ""
inputs = self._parse_and_tokenize(prefix + prompt_text, padding=False, add_special_tokens=False)
# set input_ids to None to allow empty prompt
if inputs["input_ids"].shape[-1] == 0:
inputs["input_ids"] = None
inputs["attention_mask"] = None
if self.framework == "pt" and inputs["input_ids"] is not None:
inputs = self.ensure_tensor_on_device(**inputs)
input_ids = inputs["input_ids"]
# Ensure that batch size = 1 (batch generation not allowed for now)
assert (
input_ids is None or input_ids.shape[0] == 1
), "Batch generation is currently not supported. See https://github.com/huggingface/transformers/issues/3021 for more information."
output_sequences = self.model.generate(input_ids=input_ids, **generate_kwargs) # BS x SL
result = []
for generated_sequence in output_sequences:
if self.framework == "pt" and generated_sequence is not None:
generated_sequence = generated_sequence.cpu()
generated_sequence = generated_sequence.numpy().tolist()
record = {}
if return_tensors:
record["generated_token_ids"] = generated_sequence
if return_text:
# Decode text
text = self.tokenizer.decode(
generated_sequence,
skip_special_tokens=True,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
# Remove PADDING prompt of the sequence if XLNet or Transfo-XL model is used
if input_ids is None:
prompt_length = 0
else:
prompt_length = len(
self.tokenizer.decode(
input_ids[0],
skip_special_tokens=True,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
)
# Remove PADDING prompt of the sequence if XLNet or Transfo-XL model is used
if input_ids is None:
prompt_length = 0
else:
prompt_length = len(
self.tokenizer.decode(
input_ids[0],
skip_special_tokens=True,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
)
if return_type == ReturnType.FULL_TEXT:
all_text = prompt_text + text[prompt_length:]
else:
all_text = text[prompt_length:]
if return_full_text:
all_text = prompt_text + text[prompt_length:]
else:
all_text = text[prompt_length:]
item = {"generated_text": all_text}
record.append(item)
record["generated_text"] = all_text
result.append(record)
results += [result]
if len(results) == 1:
return results[0]
return results
return record
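
A sketch of the text-generation pipeline (checkpoint and prompt are placeholders); `return_full_text`, `prefix` and the generate kwargs are all resolved by `_sanitize_parameters`, including the XLNet/Transfo-XL prefix defaulting that now happens in `__init__`.

from transformers import pipeline

# Hypothetical causal-LM checkpoint.
generator = pipeline("text-generation", model="gpt2")
print(generator("The pipelines were reworked so that", max_length=30, return_full_text=False))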


@ -1,26 +1,17 @@
import warnings
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import numpy as np
from ..file_utils import ExplicitEnum, add_end_docstrings, is_tf_available, is_torch_available
from ..modelcard import ModelCard
from ..models.bert.tokenization_bert import BasicTokenizer
from ..tokenization_utils import PreTrainedTokenizer
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline
if TYPE_CHECKING:
from ..modeling_tf_utils import TFPreTrainedModel
from ..modeling_utils import PreTrainedModel
if is_tf_available():
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
@ -104,31 +95,9 @@ class TokenClassificationPipeline(Pipeline):
default_input_names = "sequences"
def __init__(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel"],
tokenizer: PreTrainedTokenizer,
modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None,
args_parser: ArgumentHandler = TokenClassificationArgumentHandler(),
device: int = -1,
binary_output: bool = False,
ignore_labels=["O"],
task: str = "",
grouped_entities: Optional[bool] = None,
ignore_subwords: Optional[bool] = None,
aggregation_strategy: Optional[AggregationStrategy] = None,
):
super().__init__(
model=model,
tokenizer=tokenizer,
modelcard=modelcard,
framework=framework,
device=device,
binary_output=binary_output,
task=task,
)
def __init__(self, args_parser=TokenClassificationArgumentHandler(), *args, **kwargs):
self.ignore_labels = ["O"]
super().__init__(*args, **kwargs)
self.check_model_type(
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
if self.framework == "tf"
@ -137,40 +106,49 @@ class TokenClassificationPipeline(Pipeline):
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
self._args_parser = args_parser
self.ignore_labels = ignore_labels
if aggregation_strategy is None:
aggregation_strategy = AggregationStrategy.NONE
if grouped_entities is not None or ignore_subwords is not None:
def _sanitize_parameters(
self,
ignore_labels=None,
grouped_entities: Optional[bool] = None,
ignore_subwords: Optional[bool] = None,
aggregation_strategy: Optional[AggregationStrategy] = None,
):
if grouped_entities and ignore_subwords:
aggregation_strategy = AggregationStrategy.FIRST
elif grouped_entities and not ignore_subwords:
aggregation_strategy = AggregationStrategy.SIMPLE
else:
aggregation_strategy = AggregationStrategy.NONE
postprocess_params = {}
if grouped_entities is not None or ignore_subwords is not None:
if grouped_entities and ignore_subwords:
aggregation_strategy = AggregationStrategy.FIRST
elif grouped_entities and not ignore_subwords:
aggregation_strategy = AggregationStrategy.SIMPLE
else:
aggregation_strategy = AggregationStrategy.NONE
if grouped_entities is not None:
warnings.warn(
f'`grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
)
if ignore_subwords is not None:
warnings.warn(
f'`ignore_subwords` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
)
if isinstance(aggregation_strategy, str):
aggregation_strategy = AggregationStrategy[aggregation_strategy.upper()]
if grouped_entities is not None:
warnings.warn(
f'`grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
)
if ignore_subwords is not None:
warnings.warn(
f'`ignore_subwords` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
)
if (
aggregation_strategy in {AggregationStrategy.FIRST, AggregationStrategy.MAX, AggregationStrategy.AVERAGE}
and not self.tokenizer.is_fast
):
raise ValueError(
"Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option"
'to `"simple"` or use a fast tokenizer.'
)
self.aggregation_strategy = aggregation_strategy
if aggregation_strategy is not None:
if isinstance(aggregation_strategy, str):
aggregation_strategy = AggregationStrategy[aggregation_strategy.upper()]
if (
aggregation_strategy
in {AggregationStrategy.FIRST, AggregationStrategy.MAX, AggregationStrategy.AVERAGE}
and not self.tokenizer.is_fast
):
raise ValueError(
"Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option"
'to `"simple"` or use a fast tokenizer.'
)
postprocess_params["aggregation_strategy"] = aggregation_strategy
if ignore_labels is not None:
postprocess_params["ignore_labels"] = ignore_labels
return {}, {}, postprocess_params
def __call__(self, inputs: Union[str, List[str]], **kwargs):
"""
@ -198,56 +176,65 @@ class TokenClassificationPipeline(Pipeline):
"""
_inputs, offset_mappings = self._args_parser(inputs, **kwargs)
self.offset_mappings = offset_mappings
answers = []
return super().__call__(inputs, **kwargs)
for i, sentence in enumerate(_inputs):
def preprocess(self, sentence):
truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False
model_inputs = self.tokenizer(
sentence,
return_attention_mask=False,
return_tensors=self.framework,
truncation=truncation,
return_special_tokens_mask=True,
return_offsets_mapping=self.tokenizer.is_fast,
)
if self.offset_mappings:
offset_mapping = self.offset_mappings[0]
model_inputs["offset_mapping"] = offset_mapping
# Manage correct placement of the tensors
with self.device_placement():
model_inputs["sentence"] = sentence
tokens = self.tokenizer(
sentence,
return_attention_mask=False,
return_tensors=self.framework,
truncation=True,
return_special_tokens_mask=True,
return_offsets_mapping=self.tokenizer.is_fast,
)
if self.tokenizer.is_fast:
offset_mapping = tokens.pop("offset_mapping").cpu().numpy()[0]
elif offset_mappings:
offset_mapping = offset_mappings[i]
else:
offset_mapping = None
return model_inputs
special_tokens_mask = tokens.pop("special_tokens_mask").cpu().numpy()[0]
def _forward(self, model_inputs):
# Forward
special_tokens_mask = model_inputs.pop("special_tokens_mask")
offset_mapping = model_inputs.pop("offset_mapping", None)
sentence = model_inputs.pop("sentence")
if self.framework == "tf":
outputs = self.model(model_inputs.data)[0][0].numpy()
else:
outputs = self.model(**model_inputs)[0][0].numpy()
return {
"outputs": outputs,
"special_tokens_mask": special_tokens_mask,
"offset_mapping": offset_mapping,
"sentence": sentence,
**model_inputs,
}
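# Editorial note (not part of the diff): everything `postprocess` needs -- the raw logits, the
# original sentence, the special tokens mask and the offset mapping -- travels through this
# returned dict rather than being stored on `self`, which is what keeps the
# preprocess/_forward/postprocess split stateless and usable for dataset iteration.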
def postprocess(self, model_outputs, aggregation_strategy=AggregationStrategy.NONE):
outputs = model_outputs["outputs"]
sentence = model_outputs["sentence"]
input_ids = model_outputs["input_ids"][0]
offset_mapping = model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
special_tokens_mask = model_outputs["special_tokens_mask"][0].numpy()
scores = np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True)
pre_entities = self.gather_pre_entities(
sentence, input_ids, scores, offset_mapping, special_tokens_mask, aggregation_strategy
)
grouped_entities = self.aggregate(pre_entities, aggregation_strategy)
# Filter anything that is in self.ignore_labels
entities = [
entity
for entity in grouped_entities
if entity.get("entity", None) not in self.ignore_labels
and entity.get("entity_group", None) not in self.ignore_labels
]
return entities
def gather_pre_entities(
self,
@ -256,6 +243,7 @@ class TokenClassificationPipeline(Pipeline):
scores: np.ndarray,
offset_mapping: Optional[List[Tuple[int, int]]],
special_tokens_mask: np.ndarray,
aggregation_strategy: AggregationStrategy,
) -> List[dict]:
"""Fuse various numpy arrays into dicts with all the information needed for aggregation"""
pre_entities = []
@ -269,6 +257,12 @@ class TokenClassificationPipeline(Pipeline):
word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
if offset_mapping is not None:
start_ind, end_ind = offset_mapping[idx]
if self.framework == "pt":
start_ind = start_ind.item()
end_ind = end_ind.item()
else:
start_ind = int(start_ind.numpy())
end_ind = int(end_ind.numpy())
word_ref = sentence[start_ind:end_ind]
if getattr(self.tokenizer._tokenizer.model, "continuing_subword_prefix", None):
# This is a BPE, word aware tokenizer, there is a correct way
@ -276,7 +270,7 @@ class TokenClassificationPipeline(Pipeline):
is_subword = len(word) != len(word_ref)
else:
# This is a fallback heuristic. It will most likely fail on any text mixing words and punctuation, since such tokens will be treated as "words". Non word-aware models cannot do better than this, unfortunately.
if aggregation_strategy in {
AggregationStrategy.FIRST,
AggregationStrategy.AVERAGE,
AggregationStrategy.MAX,
@ -362,10 +356,11 @@ class TokenClassificationPipeline(Pipeline):
Example: micro|soft| com|pany| B-ENT I-NAME I-ENT I-ENT will be rewritten with first strategy as microsoft|
company| B-ENT I-ENT
"""
if aggregation_strategy in {
AggregationStrategy.NONE,
AggregationStrategy.SIMPLE,
}:
raise ValueError("NONE and SIMPLE strategies are invalid for word aggregation")
word_entities = []
word_group = None

View File

@ -2,15 +2,12 @@ from typing import List, Union
import numpy as np
from ..file_utils import add_end_docstrings
from ..tokenization_utils import TruncationStrategy
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline
logger = logging.get_logger(__name__)
@ -22,7 +19,7 @@ class ZeroShotClassificationArgumentHandler(ArgumentHandler):
def _parse_labels(self, labels):
if isinstance(labels, str):
labels = [label.strip() for label in labels.split(",")]
labels = [label.strip() for label in labels.split(",") if label.strip()]
return labels
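# Quick illustration (editorial, not part of the diff): the added `if label.strip()` filter drops
# empty entries, so a trailing comma no longer yields an empty candidate label:
#   _parse_labels("politics, economy, ")  ->  ["politics", "economy"]
#   (the previous version returned ["politics", "economy", ""])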
def __call__(self, sequences, labels, hypothesis_template):
@ -38,13 +35,12 @@ class ZeroShotClassificationArgumentHandler(ArgumentHandler):
if isinstance(sequences, str):
sequences = [sequences]
labels = self._parse_labels(labels)
sequence_pairs = []
for sequence in sequences:
sequence_pairs.extend([[sequence, hypothesis_template.format(label)] for label in labels])
return sequence_pairs, sequences
@add_end_docstrings(PIPELINE_INIT_ARGS)
@ -66,8 +62,8 @@ class ZeroShotClassificationPipeline(Pipeline):
"""
def __init__(self, args_parser=ZeroShotClassificationArgumentHandler(), *args, **kwargs):
self._args_parser = args_parser
super().__init__(*args, **kwargs)
if self.entailment_id == -1:
logger.warning(
"Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to "
@ -82,19 +78,11 @@ class ZeroShotClassificationPipeline(Pipeline):
return -1
def _parse_and_tokenize(
self, sequence_pairs, padding=True, add_special_tokens=True, truncation=TruncationStrategy.ONLY_FIRST, **kwargs
):
"""
Parse arguments and tokenize with `only_first` truncation so that the hypothesis (label) is not truncated
"""
return_tensors = self.framework
if getattr(self.tokenizer, "pad_token", None) is None:
# XXX some tokenizers do not have a padding token, we use simple lists
@ -141,55 +129,27 @@ class ZeroShotClassificationPipeline(Pipeline):
return inputs
def _sanitize_parameters(self, **kwargs):
if kwargs.get("multi_class", None) is not None:
kwargs["multi_label"] = kwargs["multi_class"]
logger.warning(
"The `multi_class` argument has been deprecated and renamed to `multi_label`. "
"`multi_class` will be removed in a future version of Transformers."
)
preprocess_params = {}
if "candidate_labels" in kwargs:
preprocess_params["candidate_labels"] = self._args_parser._parse_labels(kwargs["candidate_labels"])
if "hypothesis_template" in kwargs:
preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]
postprocess_params = {}
if "multi_label" in kwargs:
postprocess_params["multi_label"] = kwargs["multi_label"]
return preprocess_params, {}, postprocess_params
def __call__(
self,
sequences: Union[str, List[str]],
**kwargs,
):
"""
@ -222,53 +182,78 @@ class ZeroShotClassificationPipeline(Pipeline):
- **labels** (:obj:`List[str]`) -- The labels sorted by order of likelihood.
- **scores** (:obj:`List[float]`) -- The probabilities for each of the labels.
"""
if "multi_class" in kwargs and kwargs["multi_class"] is not None:
multi_label = kwargs.pop("multi_class")
logger.warning(
"The `multi_class` argument has been deprecated and renamed to `multi_label`. "
"`multi_class` will be removed in a future version of Transformers."
)
if sequences and isinstance(sequences, str):
sequences = [sequences]
result = super().__call__(sequences, **kwargs)
if len(result) == 1:
return result[0]
return result
def preprocess(self, inputs, candidate_labels=None, hypothesis_template="This example is {}."):
sequence_pairs, sequences = self._args_parser(inputs, candidate_labels, hypothesis_template)
model_inputs = self._parse_and_tokenize(sequence_pairs)
prepared_inputs = {
"candidate_labels": candidate_labels,
"sequences": sequences,
"inputs": model_inputs,
}
return prepared_inputs
def _forward(self, inputs):
candidate_labels = inputs["candidate_labels"]
sequences = inputs["sequences"]
model_inputs = inputs["inputs"]
if isinstance(model_inputs, list):
outputs = []
for input_ in model_inputs:
prediction = self.model(**input_)[0].cpu()
outputs.append(prediction)
else:
outputs = self.model(**model_inputs)
model_outputs = {"candidate_labels": candidate_labels, "sequences": sequences, "outputs": outputs}
return model_outputs
def postprocess(self, model_outputs, multi_label=False):
candidate_labels = model_outputs["candidate_labels"]
sequences = model_outputs["sequences"]
outputs = model_outputs["outputs"]
if self.framework == "pt":
if isinstance(outputs, list):
logits = np.concatenate([output.cpu().numpy() for output in outputs], axis=0)
else:
logits = outputs["logits"].cpu().numpy()
else:
if isinstance(outputs, list):
logits = np.concatenate([output.numpy() for output in outputs], axis=0)
else:
logits = outputs["logits"].numpy()
N = logits.shape[0]
n = len(candidate_labels)
num_sequences = N // n
reshaped_outputs = logits.reshape((num_sequences, n, -1))
if multi_label or len(candidate_labels) == 1:
# softmax over the entailment vs. contradiction dim for each label independently
entailment_id = self.entailment_id
contradiction_id = -1 if entailment_id == 0 else 0
entail_contr_logits = reshaped_outputs[..., [contradiction_id, entailment_id]]
scores = np.exp(entail_contr_logits) / np.exp(entail_contr_logits).sum(-1, keepdims=True)
scores = scores[..., 1]
else:
# softmax the "entailment" logits over all candidate labels
entail_logits = reshaped_outputs[..., self.entailment_id]
scores = np.exp(entail_logits) / np.exp(entail_logits).sum(-1, keepdims=True)
result = []
for iseq in range(num_sequences):
top_inds = list(reversed(scores[iseq].argsort()))
result.append(
{
"sequence": sequences if isinstance(sequences, str) else sequences[iseq],
"sequence": sequences[iseq],
"labels": [candidate_labels[i] for i in top_inds],
"scores": scores[iseq][top_inds].tolist(),
"scores": scores[iseq, top_inds].tolist(),
}
)
if len(result) == 1:
return result[0]
return result
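# Minimal sketch of the two scoring modes above (editorial, not part of the diff; the label ids
# below are assumed for illustration only).
import numpy as np

logits = np.array([[[2.0, 0.0, 1.0], [0.5, 0.0, 3.0]]])  # (1 sequence, 2 candidate labels, 3 model labels)
entailment_id, contradiction_id = 2, 0

# multi_label=True: softmax entailment vs. contradiction per label, scores are independent.
entail_contr = logits[..., [contradiction_id, entailment_id]]
independent_scores = np.exp(entail_contr) / np.exp(entail_contr).sum(-1, keepdims=True)
print(independent_scores[..., 1])  # one probability per label, does not sum to 1

# multi_label=False: softmax the entailment logit across all candidate labels.
entail = logits[..., entailment_id]
print(np.exp(entail) / np.exp(entail).sum(-1, keepdims=True))  # sums to 1 across labels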

View File

@ -1343,6 +1343,8 @@ def nested_simplify(obj, decimals=3):
return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
elif isinstance(obj, (str, int, np.int64)):
return obj
elif obj is None:
return obj
elif is_torch_available() and isinstance(obj, torch.Tensor):
return nested_simplify(obj.tolist(), decimals)
elif is_tf_available() and tf.is_tensor(obj):

View File

@ -15,11 +15,13 @@
import importlib
import logging
import string
import unittest
from abc import abstractmethod
from functools import lru_cache
from unittest import skipIf
from transformers import FEATURE_EXTRACTOR_MAPPING, TOKENIZER_MAPPING, AutoFeatureExtractor, AutoTokenizer, pipeline
from transformers.testing_utils import is_pipeline_test, require_torch
logger = logging.getLogger(__name__)
@ -177,3 +179,30 @@ class PipelineTestCaseMeta(type):
dct["test_small_model_tf"] = dct.get("test_small_model_tf", inner)
return type.__new__(mcs, name, bases, dct)
@is_pipeline_test
class CommonPipelineTest(unittest.TestCase):
@require_torch
def test_pipeline_iteration(self):
from torch.utils.data import Dataset
class MyDataset(Dataset):
data = [
"This is a test",
"This restaurant is great",
"This restaurant is awful",
]
def __len__(self):
return 3
def __getitem__(self, i):
return self.data[i]
text_classifier = pipeline(
task="text-classification", model="Narsil/tiny-distilbert-sequence-classification", framework="pt"
)
dataset = MyDataset()
for output in text_classifier(dataset):
self.assertEqual(output, {"label": ANY(str), "score": ANY(float)})

View File

@ -187,24 +187,15 @@ class ConversationalPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)
conversation_1 = Conversation("hello")
inputs = conversation_agent.preprocess(conversation_1)
self.assertEqual(inputs["input_ids"].tolist(), [[31373, 50256]])
conversation_2 = Conversation("how are you ?", past_user_inputs=["hello"], generated_responses=["Hi there!"])
inputs = conversation_agent.preprocess(conversation_2)
self.assertEqual(
inputs["input_ids"].tolist(), [[31373, 50256, 17250, 612, 0, 50256, 4919, 389, 345, 5633, 50256]]
)
@require_torch
@slow
def test_integration_torch_conversation_blenderbot_400M_input_ids(self):
@ -214,7 +205,7 @@ class ConversationalPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
# test1
conversation_1 = Conversation("hello")
inputs = conversation_agent.preprocess(conversation_1)
self.assertEqual(inputs["input_ids"].tolist(), [[1710, 86, 2]])
# test2
@ -225,7 +216,7 @@ class ConversationalPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
" Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie."
],
)
inputs = conversation_agent.preprocess(conversation_1)
self.assertEqual(
inputs["input_ids"].tolist(),
[
@ -271,7 +262,7 @@ class ConversationalPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
964,
21,
2, # EOS
],
],
)

View File

@ -91,6 +91,8 @@ class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
shape = self.get_shape(outputs)
self.assertEqual(shape[0], 1)
outputs = feature_extractor(["This is a test", "Another test"])
# If we send too small an input,
# there's a bug within FunnelModel (output with shape [1, 4, 2, 1] doesn't match the broadcast shape [1, 4, 2, 2])
outputs = feature_extractor(["This is a test", "Another longer test"])
shape = self.get_shape(outputs)
self.assertEqual(shape[0], 2)

View File

@ -186,7 +186,7 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
],
)
outputs = fill_masker([f"This is a {tokenizer.mask_token}", f"Another {tokenizer.mask_token}"])
outputs = fill_masker([f"This is a {tokenizer.mask_token}", f"Another {tokenizer.mask_token} great test."])
self.assertEqual(
outputs,
[

View File

@ -116,8 +116,8 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
],
)
@ -133,12 +133,12 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
nested_simplify(outputs, decimals=4),
[
[
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
],
[
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
],
],
)
@ -156,11 +156,11 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
{"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}},
{"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}},
{"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}},
{"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}},
{"score": 0.9982, "label": "remote", "box": {"xmin": 40, "ymin": 70, "xmax": 175, "ymax": 117}},
{"score": 0.9960, "label": "remote", "box": {"xmin": 333, "ymin": 72, "xmax": 368, "ymax": 187}},
{"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 639, "ymax": 473}},
{"score": 0.9988, "label": "cat", "box": {"xmin": 13, "ymin": 52, "xmax": 314, "ymax": 470}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
],
)
@ -174,18 +174,18 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
nested_simplify(outputs, decimals=4),
[
[
{"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}},
{"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}},
{"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}},
{"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}},
{"score": 0.9982, "label": "remote", "box": {"xmin": 40, "ymin": 70, "xmax": 175, "ymax": 117}},
{"score": 0.9960, "label": "remote", "box": {"xmin": 333, "ymin": 72, "xmax": 368, "ymax": 187}},
{"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 639, "ymax": 473}},
{"score": 0.9988, "label": "cat", "box": {"xmin": 13, "ymin": 52, "xmax": 314, "ymax": 470}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
],
[
{"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}},
{"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}},
{"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}},
{"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}},
{"score": 0.9982, "label": "remote", "box": {"xmin": 40, "ymin": 70, "xmax": 175, "ymax": 117}},
{"score": 0.9960, "label": "remote", "box": {"xmin": 333, "ymin": 72, "xmax": 368, "ymax": 187}},
{"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 639, "ymax": 473}},
{"score": 0.9988, "label": "cat", "box": {"xmin": 13, "ymin": 52, "xmax": 314, "ymax": 470}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
],
],
)
@ -201,11 +201,11 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
{"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}},
{"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}},
{"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}},
{"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}},
{"score": 0.9982, "label": "remote", "box": {"xmin": 40, "ymin": 70, "xmax": 175, "ymax": 117}},
{"score": 0.9960, "label": "remote", "box": {"xmin": 333, "ymin": 72, "xmax": 368, "ymax": 187}},
{"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 639, "ymax": 473}},
{"score": 0.9988, "label": "cat", "box": {"xmin": 13, "ymin": 52, "xmax": 314, "ymax": 470}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
],
)
@ -219,18 +219,18 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
nested_simplify(outputs, decimals=4),
[
[
{"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}},
{"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}},
{"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}},
{"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}},
{"score": 0.9982, "label": "remote", "box": {"xmin": 40, "ymin": 70, "xmax": 175, "ymax": 117}},
{"score": 0.9960, "label": "remote", "box": {"xmin": 333, "ymin": 72, "xmax": 368, "ymax": 187}},
{"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 639, "ymax": 473}},
{"score": 0.9988, "label": "cat", "box": {"xmin": 13, "ymin": 52, "xmax": 314, "ymax": 470}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
],
[
{"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}},
{"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}},
{"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}},
{"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}},
{"score": 0.9982, "label": "remote", "box": {"xmin": 40, "ymin": 70, "xmax": 175, "ymax": 117}},
{"score": 0.9960, "label": "remote", "box": {"xmin": 333, "ymin": 72, "xmax": 368, "ymax": 187}},
{"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 639, "ymax": 473}},
{"score": 0.9988, "label": "cat", "box": {"xmin": 13, "ymin": 52, "xmax": 314, "ymax": 470}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
],
],
)
@ -247,7 +247,7 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
{"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}},
{"score": 0.9988, "label": "cat", "box": {"xmin": 13, "ymin": 52, "xmax": 314, "ymax": 470}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
],
)

View File

@ -96,7 +96,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
def run_aggregation_strategy(self, model, tokenizer):
token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="simple")
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.SIMPLE)
outputs = token_classifier("A simple string")
self.assertIsInstance(outputs, list)
n = len(outputs)
@ -115,7 +115,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
)
token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="first")
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.FIRST)
outputs = token_classifier("A simple string")
self.assertIsInstance(outputs, list)
n = len(outputs)
@ -134,7 +134,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
)
token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="max")
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.MAX)
outputs = token_classifier("A simple string")
self.assertIsInstance(outputs, list)
n = len(outputs)
@ -155,7 +155,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
token_classifier = TokenClassificationPipeline(
model=model, tokenizer=tokenizer, aggregation_strategy="average"
)
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.AVERAGE)
outputs = token_classifier("A simple string")
self.assertIsInstance(outputs, list)
n = len(outputs)
@ -175,12 +175,12 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
with self.assertWarns(UserWarning):
token_classifier = pipeline(task="ner", model=model, tokenizer=tokenizer, grouped_entities=True)
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.SIMPLE)
with self.assertWarns(UserWarning):
token_classifier = pipeline(
task="ner", model=model, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=True
)
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.FIRST)
@require_torch
@slow
@ -533,7 +533,12 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
scores = np.array([[1, 0, 0], [0.1, 0.3, 0.6], [0.8, 0.1, 0.1]])
pre_entities = token_classifier.gather_pre_entities(
sentence,
input_ids,
scores,
offset_mapping,
special_tokens_mask,
aggregation_strategy=AggregationStrategy.NONE,
)
self.assertEqual(
nested_simplify(pre_entities),
@ -570,6 +575,20 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
],
)
@require_torch
def test_no_offset_tokenizer(self):
model_name = "Narsil/small2"
tokenizer = AutoTokenizer.from_pretrained("Narsil/small2", use_fast=False)
token_classifier = pipeline(task="token-classification", model=model_name, tokenizer=tokenizer, framework="pt")
outputs = token_classifier("This is a test !")
self.assertEqual(
nested_simplify(outputs),
[
{"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": None, "end": None},
{"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": None, "end": None},
],
)
@require_torch
def test_small_model_pt(self):
model_name = "Narsil/small2"

View File

@ -108,8 +108,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
# but we do for this one
translator = pipeline(task="translation_en_to_de")
self.assertEquals(translator.src_lang, "en")
self.assertEquals(translator.tgt_lang, "de")
self.assertEqual(translator._preprocess_params["src_lang"], "en")
self.assertEqual(translator._preprocess_params["tgt_lang"], "de")
@require_torch
@slow
@ -137,8 +137,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
def test_translation_on_odd_language(self):
model = "patrickvonplaten/t5-tiny-random"
translator = pipeline(task="translation_cn_to_ar", model=model)
self.assertEquals(translator.src_lang, "cn")
self.assertEquals(translator.tgt_lang, "ar")
self.assertEqual(translator._preprocess_params["src_lang"], "cn")
self.assertEqual(translator._preprocess_params["tgt_lang"], "ar")
@require_torch
def test_translation_default_language_selection(self):
@ -146,8 +146,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"):
translator = pipeline(task="translation", model=model)
self.assertEqual(translator.task, "translation_en_to_de")
self.assertEqual(translator.src_lang, "en")
self.assertEqual(translator.tgt_lang, "de")
self.assertEqual(translator._preprocess_params["src_lang"], "en")
self.assertEqual(translator._preprocess_params["tgt_lang"], "de")
@require_torch
def test_translation_with_no_language_no_model_fails(self):