Add chat support to text generation pipeline (#28945)
* Add chat support to text generation pipeline
* Better handling of single elements
* Deprecate ConversationalPipeline
* stash commit
* Add missing add_special_tokens kwarg
* Update chat templating docs to refer to TextGenerationPipeline instead of ConversationalPipeline
* Add ✨TF✨ tests
* @require_tf
* Add type hint
* Add specific deprecation version
* Remove unnecessary do_sample
* Remove todo - the discrepancy has been resolved
* Update src/transformers/tokenization_utils_base.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/pipelines/text_generation.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
636b03244c
commit
2f1003be86
docs/source/en/chat_templating.md
@@ -121,13 +121,15 @@ Arr, 'twas easy after all!

 ## Is there an automated pipeline for chat?

-Yes, there is: [`ConversationalPipeline`]. This pipeline is designed to make it easy to use chat models. Let's try
-the `Zephyr` example again, but this time using the pipeline:
+Yes, there is! Our text generation pipelines support chat inputs, which makes it easy to use chat models. In the past,
+we used to use a dedicated "ConversationalPipeline" class, but this has now been deprecated and its functionality
+has been merged into the [`TextGenerationPipeline`]. Let's try the `Zephyr` example again, but this time using
+a pipeline:

 ```python
 from transformers import pipeline

-pipe = pipeline("conversational", "HuggingFaceH4/zephyr-7b-beta")
+pipe = pipeline("text-generation", "HuggingFaceH4/zephyr-7b-beta")
 messages = [
     {
         "role": "system",
@@ -135,17 +137,14 @@ messages = [
     },
     {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
 ]
-print(pipe(messages))
+print(pipe(messages, max_new_tokens=128)[0]['generated_text'][-1]) # Print the assistant's response
 ```

 ```text
-Conversation id: 76d886a0-74bd-454e-9804-0467041a63dc
-system: You are a friendly chatbot who always responds in the style of a pirate
-user: How many helicopters can a human eat in one sitting?
-assistant: Matey, I'm afraid I must inform ye that humans cannot eat helicopters. Helicopters are not food, they are flying machines. Food is meant to be eaten, like a hearty plate o' grog, a savory bowl o' stew, or a delicious loaf o' bread. But helicopters, they be for transportin' and movin' around, not for eatin'. So, I'd say none, me hearties. None at all.
+{'role': 'assistant', 'content': "Matey, I'm afraid I must inform ye that humans cannot eat helicopters. Helicopters are not food, they are flying machines. Food is meant to be eaten, like a hearty plate o' grog, a savory bowl o' stew, or a delicious loaf o' bread. But helicopters, they be for transportin' and movin' around, not for eatin'. So, I'd say none, me hearties. None at all."}
 ```

-[`ConversationalPipeline`] will take care of all the details of tokenization and calling `apply_chat_template` for you -
+The pipeline will take care of all the details of tokenization and calling `apply_chat_template` for you -
 once the model has a chat template, all you need to do is initialize the pipeline and pass it the list of messages!

 ## What are "generation prompts"?
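The new indexing in the updated example works because, for chat inputs, `generated_text` now holds the full conversation: the input messages plus the model's reply appended at the end. A minimal sketch of the structure being indexed (values abridged from the docs output above; not part of the commit itself):

```python
# Sketch of the pipeline's return value for a single chat input.
outputs = [
    {
        "generated_text": [
            {"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"},
            {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
            {"role": "assistant", "content": "Matey, I'm afraid I must inform ye..."},  # abridged
        ]
    }
]
# [0] selects the first (only) result; ['generated_text'][-1] is the new assistant message.
print(outputs[0]["generated_text"][-1])
```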
@@ -191,7 +190,7 @@ Can I ask a question?<|im_end|>
 Note that this time, we've added the tokens that indicate the start of a bot response. This ensures that when the model
 generates text it will write a bot response instead of doing something unexpected, like continuing the user's
 message. Remember, chat models are still just language models - they're trained to continue text, and chat is just a
-special kind of text to them! You need to guide them with the appropriate control tokens so they know what they're
+special kind of text to them! You need to guide them with appropriate control tokens, so they know what they're
 supposed to be doing.

 Not all models require generation prompts. Some models, like BlenderBot and LLaMA, don't have any
@@ -340,8 +339,8 @@ tokenizer.chat_template = template # Set the new template
 tokenizer.push_to_hub("model_name") # Upload your new template to the Hub!
 ```

-The method [`~PreTrainedTokenizer.apply_chat_template`] which uses your chat template is called by the [`ConversationalPipeline`] class, so
-once you set the correct chat template, your model will automatically become compatible with [`ConversationalPipeline`].
+The method [`~PreTrainedTokenizer.apply_chat_template`] which uses your chat template is called by the [`TextGenerationPipeline`] class, so
+once you set the correct chat template, your model will automatically become compatible with [`TextGenerationPipeline`].

 <Tip>
 If you're fine-tuning a model for chat, in addition to setting a chat template, you should probably add any new chat
@@ -356,7 +355,7 @@ template. This will ensure that text generation tools can correctly figure out w

 Before the introduction of chat templates, chat handling was hardcoded at the model class level. For backwards
 compatibility, we have retained this class-specific handling as default templates, also set at the class level. If a
-model does not have a chat template set, but there is a default template for its model class, the `ConversationalPipeline`
+model does not have a chat template set, but there is a default template for its model class, the `TextGenerationPipeline`
 class and methods like `apply_chat_template` will use the class template instead. You can find out what the default
 template for your tokenizer is by checking the `tokenizer.default_chat_template` attribute.

@@ -407,7 +406,7 @@ I'm doing great!<|im_end|>
 ```

 The "user", "system" and "assistant" roles are the standard for chat, and we recommend using them when it makes sense,
-particularly if you want your model to operate well with [`ConversationalPipeline`]. However, you are not limited
+particularly if you want your model to operate well with [`TextGenerationPipeline`]. However, you are not limited
 to these roles - templating is extremely flexible, and any string can be a role.

 ### I want to add some chat templates! How should I get started?
@@ -418,7 +417,7 @@ not the model owner - if you're using a model with an empty chat template, or on
 template, please open a [pull request](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) to the model repository so that this attribute can be set properly!

 Once the attribute is set, that's it, you're done! `tokenizer.apply_chat_template` will now work correctly for that
-model, which means it is also automatically supported in places like `ConversationalPipeline`!
+model, which means it is also automatically supported in places like `TextGenerationPipeline`!

 By ensuring that models have this attribute, we can make sure that the whole community gets to use the full power of
 open-source models. Formatting mismatches have been haunting the field and silently harming performance for too long -
src/transformers/pipelines/conversational.py
@@ -1,4 +1,5 @@
 import uuid
+import warnings
 from typing import Any, Dict, List, Union

 from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging
@@ -232,6 +233,10 @@ class ConversationalPipeline(Pipeline):
     """

     def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "`ConversationalPipeline` is now deprecated, and the functionality has been moved to the standard `text-generation` pipeline, which now accepts lists of message dicts as well as strings. This class will be removed in v4.42.",
+            DeprecationWarning,
+        )
         super().__init__(*args, **kwargs)
         if self.tokenizer.pad_token_id is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
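For code still using the deprecated class, the migration suggested by the warning text is a one-line change. A sketch (model name borrowed from the docs example above):

```python
from transformers import pipeline

# Before: emits a DeprecationWarning as of this commit, scheduled for removal in v4.42.
# pipe = pipeline("conversational", "HuggingFaceH4/zephyr-7b-beta")

# After: the text-generation pipeline accepts the same list-of-message-dicts input.
pipe = pipeline("text-generation", "HuggingFaceH4/zephyr-7b-beta")
messages = [{"role": "user", "content": "Ahoy there!"}]
print(pipe(messages, max_new_tokens=32)[0]["generated_text"][-1])
```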
src/transformers/pipelines/text_generation.py
@@ -1,5 +1,6 @@
 import enum
+import warnings
 from typing import Dict

 from ..utils import add_end_docstrings, is_tf_available, is_torch_available
 from .base import Pipeline, build_pipeline_init_args
@@ -20,11 +21,24 @@ class ReturnType(enum.Enum):
     FULL_TEXT = 2


+class Chat:
+    """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats
+    to this format because the rest of the pipeline code tends to assume that lists of messages are
+    actually a batch of samples rather than messages in the same conversation."""
+
+    def __init__(self, messages: Dict):
+        for message in messages:
+            if not ("role" in message and "content" in message):
+                raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
+        self.messages = messages
+
+
 @add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
 class TextGenerationPipeline(Pipeline):
     """
     Language generation pipeline using any `ModelWithLMHead`. This pipeline predicts the words that will follow a
-    specified text prompt.
+    specified text prompt. It can also accept one or more chats. Each chat takes the form of a list of dicts,
+    where each dict contains "role" and "content" keys.

     Example:

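The docstring above hints at why the wrapper exists: without it, a single conversation (a list of dicts) would be indistinguishable from a batch of separate samples. A small illustration of the ambiguity the wrapper resolves (a sketch only; `Chat` is the internal class added in the diff above):

```python
# To the pipeline's batching machinery, both of these are "a list":
conversation = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi! How can I help?"},
]
batch_of_prompts = ["First prompt", "Second prompt"]

# Wrapping the conversation in Chat makes it one atomic sample, and the
# constructor validates message shape up front, e.g.:
#     Chat([{"role": "user"}])  # raises ValueError: missing 'content' key
```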
@@ -216,7 +230,15 @@ class TextGenerationPipeline(Pipeline):
            - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
              ids of the generated text.
        """
-        return super().__call__(text_inputs, **kwargs)
+        if isinstance(text_inputs, (list, tuple)) and isinstance(text_inputs[0], (list, tuple, dict)):
+            # We have one or more prompts in list-of-dicts format, so this is chat mode
+            if isinstance(text_inputs[0], dict):
+                return super().__call__(Chat(text_inputs), **kwargs)
+            else:
+                chats = [Chat(chat) for chat in text_inputs]  # 🐈 🐈 🐈
+                return super().__call__(chats, **kwargs)
+        else:
+            return super().__call__(text_inputs, **kwargs)

     def preprocess(
         self,
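The dispatch above gives `__call__` three input shapes. A hedged summary of each (model name reused from the docs example; outputs elided):

```python
from transformers import pipeline

pipe = pipeline("text-generation", "HuggingFaceH4/zephyr-7b-beta")

# 1. A string (or list of strings): the classic completion path, unchanged.
pipe("Once upon a time")

# 2. A list of dicts: treated as ONE conversation, wrapped in a single Chat.
pipe([{"role": "user", "content": "Hello!"}])

# 3. A list of lists of dicts: treated as a BATCH of conversations, one Chat each.
pipe([
    [{"role": "user", "content": "Hello!"}],
    [{"role": "user", "content": "Goodbye!"}],
])
```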
@@ -229,14 +251,25 @@ class TextGenerationPipeline(Pipeline):
         max_length=None,
         **generate_kwargs,
     ):
-        inputs = self.tokenizer(
-            prefix + prompt_text,
-            return_tensors=self.framework,
-            truncation=truncation,
-            padding=padding,
-            max_length=max_length,
-            add_special_tokens=add_special_tokens,
-        )
+        if isinstance(prompt_text, Chat):
+            inputs = self.tokenizer.apply_chat_template(
+                prompt_text.messages,
+                truncation=truncation,
+                padding=padding,
+                max_length=max_length,
+                add_generation_prompt=True,
+                return_dict=True,
+                return_tensors=self.framework,
+            )
+        else:
+            inputs = self.tokenizer(
+                prefix + prompt_text,
+                truncation=truncation,
+                padding=padding,
+                max_length=max_length,
+                add_special_tokens=add_special_tokens,
+                return_tensors=self.framework,
+            )
         inputs["prompt_text"] = prompt_text

         if handle_long_generation == "hole":
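For chat inputs, preprocessing now reduces to a single `apply_chat_template` call with `add_generation_prompt=True` (so the model continues as the assistant) and `return_dict=True` (so the pipeline gets `input_ids` plus `attention_mask`). Reproducing that step by hand looks roughly like this (a sketch; tokenizer name reused from the docs example):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
messages = [{"role": "user", "content": "Hello!"}]

# Mirrors the Chat branch of preprocess() above.
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
)
print(inputs["input_ids"].shape, inputs["attention_mask"].shape)
```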
@@ -331,7 +364,10 @@ class TextGenerationPipeline(Pipeline):

                 all_text = text[prompt_length:]
                 if return_type == ReturnType.FULL_TEXT:
-                    all_text = prompt_text + all_text
+                    if isinstance(prompt_text, str):
+                        all_text = prompt_text + all_text
+                    elif isinstance(prompt_text, Chat):
+                        all_text = prompt_text.messages + [{"role": "assistant", "content": all_text}]

                 record = {"generated_text": all_text}
                 records.append(record)
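Note that only the `FULL_TEXT` branch above is chat-aware: with `return_full_text=False`, the pipeline still returns just the newly generated string, even for chat inputs. A quick sketch using the tiny test model from this commit's own tests:

```python
from transformers import pipeline

pipe = pipeline("text-generation", "rocketknight1/tiny-gpt2-with-chatml-template")
messages = [{"role": "user", "content": "This is a test"}]

full = pipe(messages, max_new_tokens=5)
new_only = pipe(messages, max_new_tokens=5, return_full_text=False)

# full[0]["generated_text"]     -> messages + [{"role": "assistant", ...}]
# new_only[0]["generated_text"] -> just the newly generated string
```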
src/transformers/tokenization_utils_base.py
@@ -1685,6 +1685,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         truncation: bool = False,
         max_length: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
+        return_dict: bool = False,
         **tokenizer_kwargs,
     ) -> Union[str, List[int]]:
         """
@@ -1718,6 +1719,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                 - `'pt'`: Return PyTorch `torch.Tensor` objects.
                 - `'np'`: Return NumPy `np.ndarray` objects.
                 - `'jax'`: Return JAX `jnp.ndarray` objects.
+            return_dict (`bool`, *optional*, defaults to `False`):
+                Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
             **tokenizer_kwargs: Additional kwargs to pass to the tokenizer.

         Returns:
@@ -1746,15 +1749,26 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         if padding is True:
             padding = "max_length" # There's only one sequence here, so "longest" makes no sense
         if tokenize:
-            return self.encode(
-                rendered,
-                add_special_tokens=False,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                return_tensors=return_tensors,
-                **tokenizer_kwargs,
-            )
+            if return_dict:
+                return self(
+                    rendered,
+                    padding=padding,
+                    truncation=truncation,
+                    max_length=max_length,
+                    add_special_tokens=False,
+                    return_tensors=return_tensors,
+                    **tokenizer_kwargs,
+                )
+            else:
+                return self.encode(
+                    rendered,
+                    padding=padding,
+                    truncation=truncation,
+                    max_length=max_length,
+                    add_special_tokens=False,
+                    return_tensors=return_tensors,
+                    **tokenizer_kwargs,
+                )
         else:
             return rendered

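The effect of the new flag: with `return_dict=False` (the default), `apply_chat_template` keeps returning bare token ids via `encode`; with `return_dict=True` it goes through `self(...)` and returns named outputs, which is what the pipeline's chat preprocessing consumes. A hedged illustration:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
messages = [{"role": "user", "content": "Hello!"}]

ids = tokenizer.apply_chat_template(messages)  # a plain list of token ids
enc = tokenizer.apply_chat_template(messages, return_dict=True)
print(enc["input_ids"], enc["attention_mask"])  # named outputs, ready for model(**enc)
```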
tests/pipelines/test_pipelines_text_generation.py
@@ -131,6 +131,52 @@ class TextGenerationPipelineTests(unittest.TestCase):
             ],
         )

+    @require_torch
+    def test_small_chat_model_pt(self):
+        text_generator = pipeline(
+            task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt"
+        )
+        # Using `do_sample=False` to force deterministic output
+        chat1 = [
+            {"role": "system", "content": "This is a system message."},
+            {"role": "user", "content": "This is a test"},
+            {"role": "assistant", "content": "This is a reply"},
+        ]
+        chat2 = [
+            {"role": "system", "content": "This is a system message."},
+            {"role": "user", "content": "This is a second test"},
+            {"role": "assistant", "content": "This is a reply"},
+        ]
+        outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
+        expected_chat1 = chat1 + [
+            {
+                "role": "assistant",
+                "content": " factors factors factors factors factors factors factors factors factors factors",
+            }
+        ]
+        self.assertEqual(
+            outputs,
+            [
+                {"generated_text": expected_chat1},
+            ],
+        )
+
+        outputs = text_generator([chat1, chat2], do_sample=False, max_new_tokens=10)
+        expected_chat2 = chat2 + [
+            {
+                "role": "assistant",
+                "content": " factors factors factors factors factors factors factors factors factors factors",
+            }
+        ]
+
+        self.assertEqual(
+            outputs,
+            [
+                [{"generated_text": expected_chat1}],
+                [{"generated_text": expected_chat2}],
+            ],
+        )
+
     @require_tf
     def test_small_model_tf(self):
         text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf")
@@ -172,6 +218,52 @@ class TextGenerationPipelineTests(unittest.TestCase):
             ],
         )

+    @require_tf
+    def test_small_chat_model_tf(self):
+        text_generator = pipeline(
+            task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="tf"
+        )
+        # Using `do_sample=False` to force deterministic output
+        chat1 = [
+            {"role": "system", "content": "This is a system message."},
+            {"role": "user", "content": "This is a test"},
+            {"role": "assistant", "content": "This is a reply"},
+        ]
+        chat2 = [
+            {"role": "system", "content": "This is a system message."},
+            {"role": "user", "content": "This is a second test"},
+            {"role": "assistant", "content": "This is a reply"},
+        ]
+        outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
+        expected_chat1 = chat1 + [
+            {
+                "role": "assistant",
+                "content": " factors factors factors factors factors factors factors factors factors factors",
+            }
+        ]
+        self.assertEqual(
+            outputs,
+            [
+                {"generated_text": expected_chat1},
+            ],
+        )
+
+        outputs = text_generator([chat1, chat2], do_sample=False, max_new_tokens=10)
+        expected_chat2 = chat2 + [
+            {
+                "role": "assistant",
+                "content": " factors factors factors factors factors factors factors factors factors factors",
+            }
+        ]
+
+        self.assertEqual(
+            outputs,
+            [
+                [{"generated_text": expected_chat1}],
+                [{"generated_text": expected_chat2}],
+            ],
+        )
+
     def get_test_pipeline(self, model, tokenizer, processor):
         text_generator = TextGenerationPipeline(model=model, tokenizer=tokenizer)
         return text_generator, ["This is a test", "Another test"]
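To exercise just the new chat tests locally, a pytest keyword filter should do it (a sketch; the test file path is assumed from the test class shown above):

```python
import pytest

# Equivalent to: pytest tests/pipelines/test_pipelines_text_generation.py -k chat -v
pytest.main(["tests/pipelines/test_pipelines_text_generation.py", "-k", "chat", "-v"])
```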