Pipeline for Text Generation: GenerationPipeline (#3758)

* Add GenerationPipeline

* Fix parameter names

* Correct parameter __call__ parameters

* Add model type attribute and correct function calls for prepare_input

* Take out trailing commas from init attributes

* Remove unnecessary tokenization line

* Implement support for multiple text inputs

* Apply generation support for multiple input text prompts

* Take out tensor coersion

* Take out batch index

* Add text prompt to return sequence

* Squeeze token tensore before decoding

* Return only a single list of sequences if only one prompt was used

* Correct results variable name

* Add GenerationPipeline to SUPPORTED_TASKS with the alias , initalized w GPT2

* Registedred AutoModelWithLMHead for both pt and t

* Update docstring for GenerationPipeline

* Add kwargs parameter to mode.generate

* Take out kwargs parameter after all

* Add generation pipeline example in pipeline docstring

* Fix max length by squeezing tokens tensor

* Apply ensure_tensor_on_device to pytorch tensor

* Include generation step in torch.no_grad

* Take out input from prepare_xlm_input and set 'en' as default xlm_language

* Apply framework specific encoding during prepare_input

* Format w make style

* Move GenerationPipeline import to follow proper import sorting

* Take out training comma from generation dict

* Apply requested changes

* Change name to TextGenerationPipeline

* Apply TextGenerationPipeline rename to __init___

* Changing alias to

* Set input mapping as input to ensure_tensor_on_device

* Fix assertion placement

* Add test_text_generation

* Add TextGenerationPipeline to PipelineCommonTests

* Take out whitespace

* Format __init__ w black

* Fix __init__ style

* Forman __init___

* Add line to end of __init__

* Correct model tokenizer set for test_text_generation

* Ensure to return list of list, not list of string (to pass test)

* Limit test models to only 3 to limit runtime to address circleCI timeout error

* Update src/transformers/pipelines.py

Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines.py

Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines.py

Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines.py

Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines.py

Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_pipelines.py

Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines.py

Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines.py

Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines.py

Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>

* Remove argument docstring, __init__, add additional __call__ arguments, and reformat results to list of dict

* Fix blank result list

* Add TextGenerationPipeline to pipelines.rst

* Update src/transformers/pipelines.py

Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines.py

Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>

* Fix typos from adding PADDING_TEXT_TOKEN_LENGTH

* Fix incorrectly moved result list

* Update src/transformers/pipelines.py

Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines.py

* Update src/transformers/pipelines.py

* Update src/transformers/pipelines.py

* Update src/transformers/pipelines.py

* Update src/transformers/pipelines.py

* Update src/transformers/pipelines.py

* Update src/transformers/pipelines.py

* Update src/transformers/pipelines.py

* Update src/transformers/pipelines.py

* Update src/transformers/pipelines.py

* Update src/transformers/pipelines.py

* Update src/transformers/pipelines.py

Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>

* Add back generation line and make style

* Take out blank whitespace

* Apply new alis, text-generation, to test_pipelines

* Fix text generation alias in test

* Update src/transformers/pipelines.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
Lorenzo Ampil 2020-04-22 21:37:03 +08:00 committed by GitHub
parent 1dc9b3c784
commit f16540fcba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 121 additions and 0 deletions

View File

@ -66,3 +66,9 @@ SummarizationPipeline
==========================================
.. autoclass:: transformers.SummarizationPipeline
TextGenerationPipeline
==========================================
.. autoclass:: transformers.TextGenerationPipeline

View File

@ -117,6 +117,7 @@ from .pipelines import (
QuestionAnsweringPipeline,
SummarizationPipeline,
TextClassificationPipeline,
TextGenerationPipeline,
TokenClassificationPipeline,
TranslationPipeline,
pipeline,

View File

@ -520,6 +520,98 @@ class FeatureExtractionPipeline(Pipeline):
return super().__call__(*args, **kwargs).tolist()
class TextGenerationPipeline(Pipeline):
"""
Language generation pipeline using any ModelWithLMHead head. This pipeline predicts the words that will follow a specified text prompt.
This language generation pipeline can currently be loaded from the :func:`~transformers.pipeline` method using
the following task identifier(s):
- "text-generation", for generating text from a specified prompt.
The models that this pipeline can use are models that have been trained with an autoregressive language modeling objective,
which includes the uni-directional models in the library (e.g. gpt2).
See the list of available community models on
`huggingface.co/models <https://huggingface.co/models?search=&filter=lm-head>`__.
"""
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
def __call__(
self, *texts, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
):
text_inputs = self._args_parser(*texts)
results = []
for prompt_text in text_inputs:
# Manage correct placement of the tensors
with self.device_placement():
if self.model.__class__.__name__ in ["XLNetLMHeadModel", "TransfoXLLMHeadModel"]:
inputs = self._parse_and_tokenize(self.PADDING_TEXT + prompt_text)
else:
inputs = self._parse_and_tokenize(prompt_text)
if self.framework == "pt":
inputs = self.ensure_tensor_on_device(**inputs)
input_ids = inputs["input_ids"]
# Ensure that batch size = 1 (batch generation not allowed for now)
assert (
input_ids.shape[0] == 1
), "Batch generation is currently not supported. See https://github.com/huggingface/transformers/issues/3021 for more information."
output_sequences = self.model.generate(input_ids=input_ids, **generate_kwargs) # BS x SL
result = []
for generated_sequence in output_sequences:
generated_sequence = generated_sequence.tolist()
record = {}
if return_tensors:
record["generated_token_ids"] = generated_sequence
if return_text:
# Decode text
text = self.tokenizer.decode(
generated_sequence,
skip_special_tokens=True,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
# Remove PADDING prompt of the sequence if XLNet or Transfo-XL model is used
record["generated_text"] = (
prompt_text
+ text[
len(
self.tokenizer.decode(
input_ids[0],
skip_special_tokens=True,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
) :
]
)
result.append(record)
results += [result]
if len(results) == 1:
return results[0]
return results
class TextClassificationPipeline(Pipeline):
"""
Text classification pipeline using ModelForSequenceClassification head. See the
@ -1456,6 +1548,12 @@ SUPPORTED_TASKS = {
"tokenizer": ("t5-base", {"use_fast": False}),
},
},
"text-generation": {
"impl": TextGenerationPipeline,
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
"pt": AutoModelWithLMHead if is_torch_available() else None,
"default": {"model": {"pt": "gpt2", "tf": "gpt2"}, "config": None, "tokenizer": "gpt2"},
},
}

View File

@ -60,6 +60,11 @@ TEXT_CLASSIF_FINETUNED_MODELS = {
)
}
TEXT_GENERATION_FINETUNED_MODELS = {
("gpt2", "gpt2"),
("xlnet-base-cased", "xlnet-base-cased"),
}
FILL_MASK_FINETUNED_MODELS = [
(("distilroberta-base", {"use_fast": False}), "distilroberta-base", None),
]
@ -293,6 +298,16 @@ class MonoColumnInputTestCase(unittest.TestCase):
nlp, valid_inputs, invalid_inputs, mandatory_keys,
)
@require_torch
def test_text_generation(self):
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
invalid_inputs = [None]
for model, tokenizer in TEXT_GENERATION_FINETUNED_MODELS:
nlp = pipeline(task="text-generation", model=model, tokenizer=tokenizer, framework="pt")
self._test_mono_column_pipeline(
nlp, valid_inputs, invalid_inputs, {},
)
class MultiColumnInputTestCase(unittest.TestCase):
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
@ -371,6 +386,7 @@ class PipelineCommonTests(unittest.TestCase):
"translation_en_to_fr",
"translation_en_to_de",
"translation_en_to_ro",
"text-generation",
)
@slow