Binding pipelines to the CLI.

Morgan Funtowicz 2019-12-15 01:37:16 +01:00
parent 0b51532ce9
commit f1971bf303
3 changed files with 152 additions and 10 deletions

View File

@@ -2,6 +2,7 @@
from argparse import ArgumentParser
from transformers.commands.download import DownloadCommand
from transformers.commands.run import RunCommand
from transformers.commands.serving import ServeCommand
from transformers.commands.user import UserCommands
from transformers.commands.train import TrainCommand
@@ -14,9 +15,10 @@ if __name__ == '__main__':
# Register commands
ConvertCommand.register_subcommand(commands_parser)
DownloadCommand.register_subcommand(commands_parser)
+RunCommand.register_subcommand(commands_parser)
ServeCommand.register_subcommand(commands_parser)
-UserCommands.register_subcommand(commands_parser)
TrainCommand.register_subcommand(commands_parser)
+UserCommands.register_subcommand(commands_parser)
# Let's go
args = parser.parse_args()
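
Side note: every command here follows the same argparse contract. register_subcommand attaches a sub-parser and binds a factory through set_defaults(func=...), so the entry point only has to dispatch on args.func. A minimal sketch of that dispatch, assuming the factory returns an object exposing run() as RunCommand does below; the description string and the missing-subcommand guard are illustrative assumptions, not the verbatim file:

import sys
from argparse import ArgumentParser

parser = ArgumentParser(description='Transformers CLI')  # hypothetical description
commands_parser = parser.add_subparsers(help='transformers-cli command helpers')
# Each command registers itself here, e.g. RunCommand.register_subcommand(commands_parser)

args = parser.parse_args()
if not hasattr(args, 'func'):
    # No subcommand given, so no set_defaults(func=...) ever fired
    parser.print_help()
    sys.exit(1)

service = args.func(args)  # e.g. run_command_factory(args) -> RunCommand
service.run()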

View File

@@ -0,0 +1,56 @@
from argparse import ArgumentParser
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import pipeline, Pipeline, PipelineDataFormat, SUPPORTED_TASKS
def try_infer_format_from_ext(path: str):
for ext in PipelineDataFormat.SUPPORTED_FORMATS:
if path.endswith(ext):
return ext
raise Exception(
'Unable to determine the file format from the extension of {}. '
'Please provide the format explicitly through --format (one of {})'.format(path, PipelineDataFormat.SUPPORTED_FORMATS)
)
def run_command_factory(args):
nlp = pipeline(task=args.task, model=args.model, tokenizer=args.tokenizer)
format = try_infer_format_from_ext(args.input) if args.format == 'infer' else args.format
reader = PipelineDataFormat.from_str(format, args.output, args.input, args.column)
return RunCommand(nlp, reader)
class RunCommand(BaseTransformersCLICommand):
def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
self._nlp = nlp
self._reader = reader
@staticmethod
def register_subcommand(parser: ArgumentParser):
run_parser = parser.add_parser('run', help="Run a pipeline through the CLI")
run_parser.add_argument('--task', choices=SUPPORTED_TASKS.keys(), help='Task to run')
run_parser.add_argument('--model', type=str, required=True, help='Name or path to the model to instantiate.')
run_parser.add_argument('--tokenizer', type=str, help='Name of the tokenizer to use. (default: same as the model name)')
run_parser.add_argument('--column', type=str, required=True, help='Name of the column to use as input. (For multi-column input such as QA, use a comma-separated list: column1,column2)')
run_parser.add_argument('--format', type=str, default='infer', choices=PipelineDataFormat.SUPPORTED_FORMATS, help='Input format to read from')
run_parser.add_argument('--input', type=str, required=True, help='Path to the file to use for inference')
run_parser.add_argument('--output', type=str, required=True, help='Path to the file where results will be written.')
run_parser.add_argument('kwargs', nargs='*', help='Arguments to forward to the file format reader')
run_parser.set_defaults(func=run_command_factory)
def run(self):
nlp, output = self._nlp, []
for entry in self._reader:
if self._reader.is_multi_columns:
output += [nlp(**entry)]
else:
output += [nlp(entry)]
# Saving data
self._reader.save(output)
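
Putting the new command together, here is a hedged sketch of driving run_command_factory straight from Python rather than the shell. Every value below is an illustrative assumption (task key, checkpoint name, file names), and output must not exist yet, per the guard in PipelineDataFormat in the next file:

from argparse import Namespace

args = Namespace(
    task='sentiment-analysis',        # assumed key of SUPPORTED_TASKS
    model='distilbert-base-uncased',  # illustrative checkpoint name
    tokenizer=None,                   # pipeline() falls back to the model name
    column='text',                    # single input column
    format='infer',                   # resolved by try_infer_format_from_ext
    input='reviews.csv',              # hypothetical input file
    output='predictions.csv',         # must not already exist on disk
)
run_command_factory(args).run()

For multi-column tasks such as QA, --column takes name=column pairs (e.g. question=q,context=ctx), which the reader expands into keyword arguments via nlp(**entry).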

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
import csv
import json
import os
from abc import ABC, abstractmethod
from itertools import groupby
@@ -25,11 +27,13 @@ from transformers import AutoTokenizer, PreTrainedTokenizer, PretrainedConfig, \
SquadExample, squad_convert_examples_to_features, is_tf_available, is_torch_available, logger
if is_tf_available():
-from transformers import TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering, TFAutoModelForTokenClassification
+from transformers import TFAutoModel, TFAutoModelForSequenceClassification, \
+TFAutoModelForQuestionAnswering, TFAutoModelForTokenClassification
if is_torch_available():
import torch
-from transformers import AutoModelForSequenceClassification, AutoModelForQuestionAnswering, AutoModelForTokenClassification
+from transformers import AutoModel, AutoModelForSequenceClassification, \
+AutoModelForQuestionAnswering, AutoModelForTokenClassification
class Pipeline(ABC):
@@ -58,6 +62,84 @@ class Pipeline(ABC):
raise NotImplementedError()
class PipelineDataFormat:
SUPPORTED_FORMATS = ['json', 'csv']
def __init__(self, output: str, path: str, column: str):
self.output = output
self.path = path
self.column = column.split(',')
self.is_multi_columns = len(self.column) > 1
if self.is_multi_columns:
self.column = [tuple(c.split('=')) if '=' in c else (c, c) for c in self.column]
from os.path import abspath, exists
if exists(abspath(self.output)):
raise OSError('{} already exists on disk'.format(self.output))
if not exists(abspath(self.path)):
raise OSError("{} doesn't exist on disk".format(self.path))
@abstractmethod
def __iter__(self):
raise NotImplementedError()
@abstractmethod
def save(self, data: dict):
raise NotImplementedError()
@staticmethod
def from_str(name: str, output: str, path: str, column: str):
if name == 'json':
return JsonPipelineDataFormat(output, path, column)
elif name == 'csv':
return CsvPipelineDataFormat(output, path, column)
else:
raise KeyError('Unknown reader {} (Available readers are json/csv)'.format(name))
class CsvPipelineDataFormat(PipelineDataFormat):
def __init__(self, output: str, path: str, column: str):
super().__init__(output, path, column)
def __iter__(self):
with open(self.path, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
if self.is_multi_columns:
yield {k: row[c] for k, c in self.column}
else:
yield row[self.column]
def save(self, data: List[dict]):
with open(self.output, 'w') as f:
if len(data) > 0:
writer = csv.DictWriter(f, list(data[0].keys()))
writer.writeheader()
writer.writerows(data)
class JsonPipelineDataFormat(PipelineDataFormat):
def __init__(self, output: str, path: str, column: str):
super().__init__(output, path, column)
with open(path, 'r') as f:
self._entries = json.load(f)
def __iter__(self):
for entry in self._entries:
if self.is_multi_columns:
yield {k: entry[c] for k, c in self.column}
else:
yield entry[self.column]
def save(self, data: dict):
with open(self.output, 'w') as f:
json.dump(data, f)
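
The column mapping set up in PipelineDataFormat.__init__ is compact, so here is a standalone, hedged rendering of what it produces; the sample row is made up:

# Demo of the column parsing performed in PipelineDataFormat.__init__
column = 'question=q,context=ctx'.split(',')
is_multi_columns = len(column) > 1
if is_multi_columns:
    column = [tuple(c.split('=')) if '=' in c else (c, c) for c in column]
# column is now [('question', 'q'), ('context', 'ctx')]

# __iter__ remaps each row with these pairs so RunCommand.run can call nlp(**entry)
row = {'q': 'Who bound pipelines to the CLI?', 'ctx': 'Morgan bound pipelines to the CLI.'}
entry = {k: row[c] for k, c in column}
# entry == {'question': 'Who bound pipelines to the CLI?', 'context': 'Morgan bound pipelines to the CLI.'}
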
class FeatureExtractionPipeline(Pipeline):
def __call__(self, *texts, **kwargs):
@@ -127,7 +209,7 @@ class NerPipeline(Pipeline):
label_idx = score.argmax()
answer += [{
-'word': words[idx - 1], 'score': score[label_idx], 'entity': self.model.config.id2label[label_idx]
+'word': words[idx - 1], 'score': score[label_idx].item(), 'entity': self.model.config.id2label[label_idx]
}]
# Update token start
@@ -270,16 +352,18 @@ class QuestionAnsweringPipeline(Pipeline):
char_to_word = np.array(example.char_to_word_offset)
# Convert the answer (tokens) back to the original text
-answers += [[
+answers += [
{
-'score': score,
-'start': np.where(char_to_word == feature.token_to_orig_map[s])[0][0],
-'end': np.where(char_to_word == feature.token_to_orig_map[e])[0][-1],
+'score': score.item(),
+'start': np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
+'end': np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
'answer': ' '.join(example.doc_tokens[feature.token_to_orig_map[s]: feature.token_to_orig_map[e] + 1])
}
for s, e, score in zip(starts, ends, scores)
-]]
+]
+if len(answers) == 1:
+return answers[0]
return answers
def decode(self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int) -> Tuple:
@@ -363,7 +447,7 @@ def pipeline(task: str, model, config: Optional[PretrainedConfig] = None, tokeni
Utility factory method to build a Pipeline.
"""
# Try to infer tokenizer from model name (if provided as str)
-if not isinstance(tokenizer, PreTrainedTokenizer):
+if tokenizer is None:
if not isinstance(model, str):
# Impossible to guess the right tokenizer here
raise Exception('Tokenizer cannot be None if the provided model is a PreTrainedModel instance')
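
In practice this guard means a plain string is the easy path: the tokenizer is resolved from the same identifier as the model. A hedged usage sketch, with an illustrative checkpoint name:

from transformers import pipeline

# Tokenizer inferred from the model identifier, per the branch above
nlp = pipeline(task='question-answering', model='distilbert-base-uncased-distilled-squad')
answer = nlp(question='Who bound pipelines to the CLI?', context='Morgan bound pipelines to the CLI.')

# Passing a pre-instantiated model object instead requires an explicit tokenizer,
# otherwise the Exception above is raised.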