Pipeline for Text Generation: GenerationPipeline (#3758)
* Add GenerationPipeline * Fix parameter names * Correct parameter __call__ parameters * Add model type attribute and correct function calls for prepare_input * Take out trailing commas from init attributes * Remove unnecessary tokenization line * Implement support for multiple text inputs * Apply generation support for multiple input text prompts * Take out tensor coersion * Take out batch index * Add text prompt to return sequence * Squeeze token tensore before decoding * Return only a single list of sequences if only one prompt was used * Correct results variable name * Add GenerationPipeline to SUPPORTED_TASKS with the alias , initalized w GPT2 * Registedred AutoModelWithLMHead for both pt and t * Update docstring for GenerationPipeline * Add kwargs parameter to mode.generate * Take out kwargs parameter after all * Add generation pipeline example in pipeline docstring * Fix max length by squeezing tokens tensor * Apply ensure_tensor_on_device to pytorch tensor * Include generation step in torch.no_grad * Take out input from prepare_xlm_input and set 'en' as default xlm_language * Apply framework specific encoding during prepare_input * Format w make style * Move GenerationPipeline import to follow proper import sorting * Take out training comma from generation dict * Apply requested changes * Change name to TextGenerationPipeline * Apply TextGenerationPipeline rename to __init___ * Changing alias to * Set input mapping as input to ensure_tensor_on_device * Fix assertion placement * Add test_text_generation * Add TextGenerationPipeline to PipelineCommonTests * Take out whitespace * Format __init__ w black * Fix __init__ style * Forman __init___ * Add line to end of __init__ * Correct model tokenizer set for test_text_generation * Ensure to return list of list, not list of string (to pass test) * Limit test models to only 3 to limit runtime to address circleCI timeout error * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen 
<patrick.v.platen@gmail.com> * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Update tests/test_pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Remove argument docstring, __init__, add additional __call__ arguments, and reformat results to list of dict * Fix blank result list * Add TextGenerationPipeline to pipelines.rst * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Fix typos from adding PADDING_TEXT_TOKEN_LENGTH * Fix incorrectly moved result list * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Add back generation line and make style * 
Take out blank whitespace * Apply new alis, text-generation, to test_pipelines * Fix text generation alias in test * Update src/transformers/pipelines.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
parent
1dc9b3c784
commit
f16540fcba
|
@ -66,3 +66,9 @@ SummarizationPipeline
==========================================

.. autoclass:: transformers.SummarizationPipeline

TextGenerationPipeline
==========================================

.. autoclass:: transformers.TextGenerationPipeline
|
|
@ -117,6 +117,7 @@ from .pipelines import (
|
|||
QuestionAnsweringPipeline,
|
||||
SummarizationPipeline,
|
||||
TextClassificationPipeline,
|
||||
TextGenerationPipeline,
|
||||
TokenClassificationPipeline,
|
||||
TranslationPipeline,
|
||||
pipeline,
|
||||
|
|
|
@ -520,6 +520,98 @@ class FeatureExtractionPipeline(Pipeline):
|
|||
return super().__call__(*args, **kwargs).tolist()
|
||||
|
||||
|
||||
class TextGenerationPipeline(Pipeline):
    """
    Language generation pipeline using any ModelWithLMHead head. This pipeline predicts the words that will follow a specified text prompt.

    This language generation pipeline can currently be loaded from the :func:`~transformers.pipeline` method using
    the following task identifier(s):

    - "text-generation", for generating text from a specified prompt.

    The models that this pipeline can use are models that have been trained with an autoregressive language modeling objective,
    which includes the uni-directional models in the library (e.g. gpt2).
    See the list of available community models on
    `huggingface.co/models <https://huggingface.co/models?search=&filter=lm-head>`__.
    """

    # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
    # in https://github.com/rusiaaman/XLNet-gen#methodology
    # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
    PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
    (except for Alexei and Maria) are discovered.
    The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
    remainder of the story. 1883 Western Siberia,
    a young Grigori Rasputin is asked by his father and a group of men to perform magic.
    Rasputin has a vision and denounces one of the men as a horse thief. Although his
    father initially slaps him for making such an accusation, Rasputin watches as the
    man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
    the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
    with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""

    def __call__(
        self, *texts, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
    ):
        """
        Generate continuations for one or more text prompts.

        Args:
            *texts: one or several prompt strings (or a single list of strings),
                parsed by the pipeline's argument parser.
            return_tensors (bool): if True, each result dict contains the generated
                token ids under ``"generated_token_ids"``.
            return_text (bool): if True, each result dict contains the decoded text
                under ``"generated_text"``.
            clean_up_tokenization_spaces (bool): forwarded to ``tokenizer.decode``.
            **generate_kwargs: forwarded verbatim to ``self.model.generate``.

        Returns:
            For a single prompt, a list of dicts (one per generated sequence);
            for several prompts, a list of such lists.
        """
        text_inputs = self._args_parser(*texts)

        results = []
        for prompt_text in text_inputs:
            # Manage correct placement of the tensors
            with self.device_placement():
                # XLNet and Transfo-XL generate poorly from very short prompts, so a
                # long padding prefix is prepended (stripped again after decoding below).
                if self.model.__class__.__name__ in ["XLNetLMHeadModel", "TransfoXLLMHeadModel"]:
                    inputs = self._parse_and_tokenize(self.PADDING_TEXT + prompt_text)
                else:
                    inputs = self._parse_and_tokenize(prompt_text)

                if self.framework == "pt":
                    inputs = self.ensure_tensor_on_device(**inputs)

                input_ids = inputs["input_ids"]

                # Ensure that batch size = 1 (batch generation not allowed for now)
                assert (
                    input_ids.shape[0] == 1
                ), "Batch generation is currently not supported. See https://github.com/huggingface/transformers/issues/3021 for more information."

                output_sequences = self.model.generate(input_ids=input_ids, **generate_kwargs)  # BS x SL

            result = []
            for generated_sequence in output_sequences:
                generated_sequence = generated_sequence.tolist()
                record = {}
                if return_tensors:
                    record["generated_token_ids"] = generated_sequence
                if return_text:
                    # Decode text
                    text = self.tokenizer.decode(
                        generated_sequence,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )

                    # Remove PADDING prompt of the sequence if XLNet or Transfo-XL model is used:
                    # drop the decoded input prefix (prompt, possibly padded) and re-attach the
                    # original prompt so callers always get prompt + continuation.
                    record["generated_text"] = (
                        prompt_text
                        + text[
                            len(
                                self.tokenizer.decode(
                                    input_ids[0],
                                    skip_special_tokens=True,
                                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                                )
                            ) :
                        ]
                    )

                result.append(record)
            results += [result]

        # Single prompt: unwrap the outer list for convenience.
        if len(results) == 1:
            return results[0]

        return results
|
||||
|
||||
|
||||
class TextClassificationPipeline(Pipeline):
|
||||
"""
|
||||
Text classification pipeline using ModelForSequenceClassification head. See the
|
||||
|
@ -1456,6 +1548,12 @@ SUPPORTED_TASKS = {
|
|||
"tokenizer": ("t5-base", {"use_fast": False}),
|
||||
},
|
||||
},
|
||||
"text-generation": {
|
||||
"impl": TextGenerationPipeline,
|
||||
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
||||
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
||||
"default": {"model": {"pt": "gpt2", "tf": "gpt2"}, "config": None, "tokenizer": "gpt2"},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -60,6 +60,11 @@ TEXT_CLASSIF_FINETUNED_MODELS = {
|
|||
)
|
||||
}
|
||||
|
||||
# (model, tokenizer) identifier pairs exercised by test_text_generation.
# Limited to a small set to keep CI runtime bounded.
TEXT_GENERATION_FINETUNED_MODELS = {
    ("gpt2", "gpt2"),
    ("xlnet-base-cased", "xlnet-base-cased"),
}

# Fixtures for the fill-mask pipeline test: ((tokenizer, tokenizer_kwargs), model, config).
FILL_MASK_FINETUNED_MODELS = [
    (("distilroberta-base", {"use_fast": False}), "distilroberta-base", None),
]
|
||||
|
@ -293,6 +298,16 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
|||
nlp, valid_inputs, invalid_inputs, mandatory_keys,
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_text_generation(self):
|
||||
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
|
||||
invalid_inputs = [None]
|
||||
for model, tokenizer in TEXT_GENERATION_FINETUNED_MODELS:
|
||||
nlp = pipeline(task="text-generation", model=model, tokenizer=tokenizer, framework="pt")
|
||||
self._test_mono_column_pipeline(
|
||||
nlp, valid_inputs, invalid_inputs, {},
|
||||
)
|
||||
|
||||
|
||||
class MultiColumnInputTestCase(unittest.TestCase):
|
||||
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
|
||||
|
@ -371,6 +386,7 @@ class PipelineCommonTests(unittest.TestCase):
|
|||
"translation_en_to_fr",
|
||||
"translation_en_to_de",
|
||||
"translation_en_to_ro",
|
||||
"text-generation",
|
||||
)
|
||||
|
||||
@slow
|
||||
|
|
Loading…
Reference in New Issue