Compare commits

5 commits  `main...remove_con`

| Author | SHA1 | Date |
|---|---|---|
| Matt | 427f23abfb | |
| Matt | 51a4f8d014 | |
| Matt | b78851e9bd | |
| Matt | b53ebaca35 | |
| Matt | c9e4f75f9b | |
@@ -386,14 +386,6 @@ Pipelines available for computer vision tasks include the following.
 
 Pipelines available for natural language processing tasks include the following.
 
-### ConversationalPipeline
-
-[[autodoc]] Conversation
-
-[[autodoc]] ConversationalPipeline
-    - __call__
-    - all
-
 ### FillMaskPipeline
 
 [[autodoc]] FillMaskPipeline
@@ -180,8 +180,8 @@ tokenizer.chat_template = template  # Set the new template
 tokenizer.push_to_hub("model_name")  # Upload your new template to the Hub!
 ```
 
-The [`~PreTrainedTokenizer.apply_chat_template`] method is called by the [`ConversationalPipeline`] class in order to use your chat template.
-So once you set the correct chat template, your model automatically becomes compatible with [`ConversationalPipeline`].
+The [`~PreTrainedTokenizer.apply_chat_template`] method is called by the `TextGenerationPipeline` class in order to use your chat template.
+So once you set the correct chat template, your model automatically becomes compatible with [`TextGenerationPipeline`].
 
 
 ## What are "default" templates?
 
@@ -189,7 +189,7 @@ tokenizer.push_to_hub("model_name")  # Upload your new template to the Hub!
 Before chat templates were introduced, chat handling was hardcoded at the model class level.
 For backward compatibility, this class-specific handling has been kept as default templates, set at the class level.
 If a model does not have a chat template set, but there is a default template for its model class,
-the `ConversationalPipeline` class and methods such as `apply_chat_template` will use the class template.
+the `TextGenerationPipeline` class and methods such as `apply_chat_template` will use the class template.
 To check a tokenizer's default chat template, look at the `tokenizer.default_chat_template` attribute.
 
 We do this purely for backward compatibility, so that existing workflows do not break.
 
@@ -233,7 +233,7 @@ I'm doing great!<|im_end|>
 ```
 
 The "user", "system", and "assistant" roles are the standard for chat.
-We recommend using these roles in particular when you want things to work smoothly with [`ConversationalPipeline`]. However, you are not restricted to them; templates are very flexible, and any string can be used as a role.
+We recommend using these roles in particular when you want things to work smoothly with `TextGenerationPipeline`. However, you are not restricted to them; templates are very flexible, and any string can be used as a role.
 
 ## I want to use chat templates! How should I get started?
 
@@ -242,7 +242,7 @@ I'm doing great!<|im_end|>
 please open a [pull request](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) so that this attribute can be set correctly.
 
 Once the attribute is set, that's it! `tokenizer.apply_chat_template` will then work correctly for that model, which means it is
-also supported automatically in places like `ConversationalPipeline`.
+also supported automatically in places like `TextGenerationPipeline`.
 
 By making sure models have this attribute, we can ensure that the whole community gets to use the full power of open-source models.
 Format mismatches have plagued this field and silently hurt performance for too long; it's time to put an end to them!
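For reference, here is a minimal sketch of the workflow the documentation changes above describe, assuming a generic tokenizer and an illustrative ChatML-style template (neither is part of this diff): a Jinja template is assigned to `tokenizer.chat_template`, and `apply_chat_template` then renders a list of message dicts with it.

```python
from transformers import AutoTokenizer

# Illustrative checkpoint and template; any tokenizer with a chat template works the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.chat_template = (
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)

chat = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi there!"},
]

# tokenize=False returns the rendered prompt string instead of token ids.
prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
print(prompt)
```

Pushing the tokenizer to the Hub afterwards, as in the hunk above, is what makes the template available to anyone loading the model.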
@@ -388,14 +388,6 @@ my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline)
 
 Pipelines available for natural language processing tasks include the following.
 
-### ConversationalPipeline
-
-[[autodoc]] Conversation
-
-[[autodoc]] ConversationalPipeline
-    - __call__
-    - all
-
 ### FillMaskPipeline
 
 [[autodoc]] FillMaskPipeline
@@ -117,12 +117,12 @@ Matey, I'm afraid I must inform ye that humans cannot eat helicopters. Helicopte
 
 ## Is there an automated pipeline for chat?
 
-Yes, there is: [`ConversationalPipeline`]. This pipeline is designed to make it easy to use chat models. Let's try the Zephyr example again, but this time with a pipeline:
+Yes, there is: [`TextGenerationPipeline`]. This pipeline is designed to make it easy to use chat models. Let's try the Zephyr example again, but this time with a pipeline:
 
 ```python
 from transformers import pipeline
 
-pipe = pipeline("conversational", "HuggingFaceH4/zephyr-7b-beta")
+pipe = pipeline("text-generation", "HuggingFaceH4/zephyr-7b-beta")
 messages = [
     {
         "role": "system",
@@ -130,17 +130,14 @@ messages = [
     },
     {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
 ]
-print(pipe(messages))
+print(pipe(messages, max_new_tokens=256)['generated_text'][-1])
 ```
 
 ```text
-Conversation id: 76d886a0-74bd-454e-9804-0467041a63dc
-system: You are a friendly chatbot who always responds in the style of a pirate
-user: How many helicopters can a human eat in one sitting?
-assistant: Matey, I'm afraid I must inform ye that humans cannot eat helicopters. Helicopters are not food, they are flying machines. Food is meant to be eaten, like a hearty plate o' grog, a savory bowl o' stew, or a delicious loaf o' bread. But helicopters, they be for transportin' and movin' around, not for eatin'. So, I'd say none, me hearties. None at all.
+{'role': 'assistant', 'content': "Matey, I'm afraid I must inform ye that humans cannot eat helicopters. Helicopters are not food, they are flying machines. Food is meant to be eaten, like a hearty plate o' grog, a savory bowl o' stew, or a delicious loaf o' bread. But helicopters, they be for transportin' and movin' around, not for eatin'. So, I'd say none, me hearties. None at all."}
 ```
 
-[`ConversationalPipeline`] takes care of all the tokenization and calls `apply_chat_template` for you. Once the model has a chat template, all you need to do is initialize the pipeline and pass it the list of messages!
+[`TextGenerationPipeline`] takes care of all the tokenization and calls `apply_chat_template` for you. Once the model has a chat template, all you need to do is initialize the pipeline and pass it the list of messages!
 
 ## What are "generation prompts"?
 
@@ -317,12 +314,12 @@ tokenizer.chat_template = template  # Set the new template
 tokenizer.push_to_hub("model_name")  # Upload your new template to the Hub!
 ```
 
-Since the [`~PreTrainedTokenizer.apply_chat_template`] method is called by the [`ConversationalPipeline`] class,
-your model will automatically be compatible with [`ConversationalPipeline`] once you have set a chat template.
+Since the [`~PreTrainedTokenizer.apply_chat_template`] method is called by the [`TextGenerationPipeline`] class,
+your model will automatically be compatible with [`TextGenerationPipeline`] once you have set a chat template.
 ### What are "default" templates?
 
 Before chat templates (chat_template) were introduced, chat prompts were handled by hardcoded logic in the model code. For backward compatibility, we have kept this hardcoded handling.
-If a model has no chat template set, but there is a default template for its model class, the `ConversationalPipeline` class and methods such as `apply_chat_template` will use that default template.
+If a model has no chat template set, but there is a default template for its model class, the `TextGenerationPipeline` class and methods such as `apply_chat_template` will use that default template.
 You can find a tokenizer's default template by checking the `tokenizer.default_chat_template` attribute.
 
 We do this purely for backward compatibility, to avoid breaking any existing workflows. Even if the default chat template works for your model,
 
@@ -367,7 +364,7 @@ How are you?<|im_end|>
 I'm doing great!<|im_end|>
 ```
 
-`user`, `system`, and `assistant` are the standard roles for chat assistant models. If you want your model to work with [`ConversationalPipeline`], we recommend using these roles.
+`user`, `system`, and `assistant` are the standard roles for chat assistant models. If you want your model to work with [`TextGenerationPipeline`], we recommend using these roles.
 But you are not limited to them; templates are very flexible, and any string can be used as a role.
 
 ### How do I add a chat template?
 
@@ -378,7 +375,7 @@ I'm doing great!<|im_end|>
 please open a [pull request](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) so that this attribute can be set correctly!
 
 Once the attribute is set, you're done! `tokenizer.apply_chat_template` will now work correctly for that model,
-which means it is also automatically supported in places like `ConversationalPipeline`!
+which means it is also automatically supported in places like `TextGenerationPipeline`!
 
 By making sure models have this attribute, we can ensure that the whole community gets to use the full power of open-source models.
 Format mismatches have plagued this field and silently hurt performance for too long; it's time to put an end to them!
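A minimal migration sketch for the change these docs describe, assuming any chat model with a chat template (the checkpoint below is the same illustrative one used in the docs): the removed `"conversational"` task and `Conversation` object are replaced by passing the message list directly to the `"text-generation"` pipeline.

```python
from transformers import pipeline

messages = [
    {"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"},
    {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
]

# Before: pipe = pipeline("conversational", ...) plus wrapping the messages in a Conversation object.
# Now the text-generation pipeline accepts the list of message dicts directly.
pipe = pipeline("text-generation", "HuggingFaceH4/zephyr-7b-beta")
out = pipe(messages, max_new_tokens=256)

# For chat input, generated_text holds the updated message list; the last entry is the assistant reply.
print(out[0]["generated_text"][-1])
```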
@@ -362,14 +362,6 @@ my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline)
 
 Pipelines available for natural language processing tasks include the following.
 
-### ConversationalPipeline
-
-[[autodoc]] Conversation
-
-[[autodoc]] ConversationalPipeline
-    - __call__
-    - all
-
 ### FillMaskPipeline
 
 [[autodoc]] FillMaskPipeline
@@ -799,8 +799,6 @@ _import_structure = {
     "pipelines": [
         "AudioClassificationPipeline",
         "AutomaticSpeechRecognitionPipeline",
-        "Conversation",
-        "ConversationalPipeline",
         "CsvPipelineDataFormat",
         "DepthEstimationPipeline",
         "DocumentQuestionAnsweringPipeline",

@@ -5428,8 +5426,6 @@ if TYPE_CHECKING:
     from .pipelines import (
         AudioClassificationPipeline,
         AutomaticSpeechRecognitionPipeline,
-        Conversation,
-        ConversationalPipeline,
         CsvPipelineDataFormat,
         DepthEstimationPipeline,
         DocumentQuestionAnsweringPipeline,
@@ -20,7 +20,6 @@ from typing import Dict, List, Literal, Union
 
 from tokenizers import processors
 
-from ...pipelines.conversational import Conversation
 from ...tokenization_utils_base import BatchEncoding
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
@@ -413,7 +412,7 @@ class CohereTokenizerFast(PreTrainedTokenizerFast):
 
     def apply_tool_use_template(
         self,
-        conversation: Union[List[Dict[str, str]], "Conversation"],
+        conversation: Union[List[Dict[str, str]]],
         tools: List[Dict],
         **kwargs,
     ) -> Union[str, List[int]]:
@@ -424,13 +423,13 @@ class CohereTokenizerFast(PreTrainedTokenizerFast):
 
         Conceptually, this works in the same way as `apply_chat_template`, but takes an additional `tools` parameter.
 
-        Converts a Conversation object or a list of dictionaries with `"role"` and `"content"` keys and a list of available
+        Converts a chat in the form of a list of dictionaries with `"role"` and `"content"` keys and a list of available
         tools for the model to use into a prompt string, or a list of token ids.
         This method will use the tokenizer's `default_tool_use_template` template specified at the class level.
         You can override the default template using the `tool_use_template` kwarg but the quality of your results may decrease.
 
         Args:
-            conversation (Union[List[Dict[str, str]], "Conversation"]): A Conversation object or list of dicts
+            conversation (Union[List[Dict[str, str]]]): A list of dicts
                 with "role" and "content" keys, representing the chat history so far.
             tools (List[Dict]): a list of tools to render into the prompt for the model to choose from.
                 See an example at the bottom of the docstring.
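A hedged sketch of how the updated `apply_tool_use_template` signature is called, now that the chat is a plain list of role/content dicts rather than a `Conversation` object. The checkpoint and the tool schema fields below are illustrative assumptions, not taken from this diff; the docstring's own example remains the authoritative reference.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")  # illustrative checkpoint

chat = [{"role": "user", "content": "What's the weather like in Paris today?"}]

# Assumed tool schema for illustration only; see the docstring example for the exact expected fields.
tools = [
    {
        "name": "get_weather",
        "description": "Returns the current weather for a given city.",
        "parameter_definitions": {
            "city": {"description": "The city to look up", "type": "str", "required": True},
        },
    }
]

# tokenize=False asks for the rendered prompt string rather than token ids.
prompt = tokenizer.apply_tool_use_template(chat, tools=tools, tokenize=False)
print(prompt)
```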
@@ -568,7 +567,7 @@ class CohereTokenizerFast(PreTrainedTokenizerFast):
 
     def apply_grounded_generation_template(
         self,
-        conversation: Union[List[Dict[str, str]], "Conversation"],
+        conversation: Union[List[Dict[str, str]]],
         documents: List[Dict],
         citation_mode: Literal["fast", "accurate"] = "accurate",
         **kwargs,
@@ -580,13 +579,13 @@ class CohereTokenizerFast(PreTrainedTokenizerFast):
         Conceptually, this works in the same way as `apply_chat_template`, but takes additional `documents`
         and `citation_mode` parameters.
 
-        Converts a Conversation object or a list of dictionaries with `"role"` and `"content"` keys and a list of
+        Converts a list of dictionaries with `"role"` and `"content"` keys and a list of
         documents for the model to ground its response on into a prompt string, or a list of token ids.
         This method will use the tokenizer's `grounded_generation_template` template specified at the class level.
         You can override the default template using the `grounded_generation_template` kwarg but the quality of your results may decrease.
 
         Args:
-            conversation (Union[List[Dict[str, str]], "Conversation"]): A Conversation object or list of dicts
+            conversation (Union[List[Dict[str, str]]]): A list of dicts
                 with "role" and "content" keys, representing the chat history so far.
             documents (List[Dict[str, str]]): A list of dicts, representing documents or tool outputs to ground your
                 generation on. A document is a semi-structured dict with a string-to-string mapping. Common fields are
@@ -26,7 +26,6 @@ from ...utils import TensorType, logging
 
 
 if TYPE_CHECKING:
-    from ...pipelines.conversational import Conversation
     from ...tokenization_utils_base import PreTokenizedInput
 
 
@@ -255,7 +254,7 @@ class Idefics2Processor(ProcessorMixin):
 
     def apply_chat_template(
         self,
-        conversation: Union[List[Dict[str, str]], "Conversation"],
+        conversation: Union[List[Dict[str, str]]],
        chat_template: Optional[str] = None,
         tokenize: bool = False,
         **kwargs,
@@ -269,7 +268,7 @@ class Idefics2Processor(ProcessorMixin):
         tokens to the sequence length or adding the surrounding tokens e.g. <fake_image_token>.
 
         Args:
-            conversation (`Union[List[Dict, str, str], "Conversation"]`):
+            conversation (`Union[List[Dict, str, str]]`):
                 The conversation to format.
             chat_template (`Optional[str]`, *optional*):
                 The Jinja template to use for formatting the conversation. If not provided, the default chat template
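A sketch of the call pattern implied by the updated Idefics2 signature, assuming the usual interleaved message format in which image slots appear as `{"type": "image"}` entries; the checkpoint and message content are illustrative.

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")  # illustrative checkpoint

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    }
]

# tokenize defaults to False here, so this returns the formatted prompt string
# (with image placeholder tokens inserted) rather than token ids.
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
print(prompt)
```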
@@ -58,7 +58,6 @@ from .base import (
     get_default_model_and_revision,
     infer_framework_load_model,
 )
-from .conversational import Conversation, ConversationalPipeline
 from .depth_estimation import DepthEstimationPipeline
 from .document_question_answering import DocumentQuestionAnsweringPipeline
 from .feature_extraction import FeatureExtractionPipeline
@@ -340,15 +339,6 @@ SUPPORTED_TASKS = {
         },
         "type": "multimodal",
     },
-    "conversational": {
-        "impl": ConversationalPipeline,
-        "tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
-        "pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),
-        "default": {
-            "model": {"pt": ("microsoft/DialoGPT-medium", "8bada3b"), "tf": ("microsoft/DialoGPT-medium", "8bada3b")}
-        },
-        "type": "text",
-    },
     "image-classification": {
         "impl": ImageClassificationPipeline,
         "tf": (TFAutoModelForImageClassification,) if is_tf_available() else (),
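With the `"conversational"` entry gone from `SUPPORTED_TASKS`, custom tasks are still added through the public registry rather than by editing this dict. A rough sketch, assuming the `PIPELINE_REGISTRY.register_pipeline` API and a hypothetical `MyPipeline` subclass and task name:

```python
from transformers import AutoModelForCausalLM, Pipeline
from transformers.pipelines import PIPELINE_REGISTRY


class MyPipeline(Pipeline):
    # Minimal stub; a real pipeline implements meaningful pre/forward/postprocessing.
    def _sanitize_parameters(self, **kwargs):
        return {}, {}, {}

    def preprocess(self, inputs, **kwargs):
        return self.tokenizer(inputs, return_tensors=self.framework)

    def _forward(self, model_inputs, **kwargs):
        return self.model(**model_inputs)

    def postprocess(self, model_outputs, **kwargs):
        return model_outputs


# Hypothetical task name, for illustration only.
PIPELINE_REGISTRY.register_pipeline(
    "my-chat-task",
    pipeline_class=MyPipeline,
    pt_model=AutoModelForCausalLM,
    type="text",
)
```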
@@ -593,7 +583,6 @@ def pipeline(
 
             - `"audio-classification"`: will return a [`AudioClassificationPipeline`].
             - `"automatic-speech-recognition"`: will return a [`AutomaticSpeechRecognitionPipeline`].
-            - `"conversational"`: will return a [`ConversationalPipeline`].
             - `"depth-estimation"`: will return a [`DepthEstimationPipeline`].
             - `"document-question-answering"`: will return a [`DocumentQuestionAnsweringPipeline`].
             - `"feature-extraction"`: will return a [`FeatureExtractionPipeline`].
@ -1,322 +0,0 @@
|
|||
import uuid
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging
|
||||
from .base import Pipeline, build_pipeline_init_args
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Conversation:
|
||||
"""
|
||||
Utility class containing a conversation and its history. This class is meant to be used as an input to the
|
||||
[`ConversationalPipeline`]. The conversation contains several utility functions to manage the addition of new user
|
||||
inputs and generated model responses.
|
||||
|
||||
Arguments:
|
||||
messages (Union[str, List[Dict[str, str]]], *optional*):
|
||||
The initial messages to start the conversation, either a string, or a list of dicts containing "role" and
|
||||
"content" keys. If a string is passed, it is interpreted as a single message with the "user" role.
|
||||
conversation_id (`uuid.UUID`, *optional*):
|
||||
Unique identifier for the conversation. If not provided, a random UUID4 id will be assigned to the
|
||||
conversation.
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
conversation = Conversation("Going to the movies tonight - any suggestions?")
|
||||
conversation.add_message({"role": "assistant", "content": "The Big lebowski."})
|
||||
conversation.add_message({"role": "user", "content": "Is it good?"})
|
||||
```"""
|
||||
|
||||
def __init__(
|
||||
self, messages: Union[str, List[Dict[str, str]]] = None, conversation_id: uuid.UUID = None, **deprecated_kwargs
|
||||
):
|
||||
if not conversation_id:
|
||||
conversation_id = uuid.uuid4()
|
||||
|
||||
if messages is None:
|
||||
text = deprecated_kwargs.pop("text", None)
|
||||
if text is not None:
|
||||
messages = [{"role": "user", "content": text}]
|
||||
else:
|
||||
messages = []
|
||||
elif isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
# This block deals with the legacy args - new code should just totally
|
||||
# avoid past_user_inputs and generated_responses
|
||||
self._num_processed_user_inputs = 0
|
||||
generated_responses = deprecated_kwargs.pop("generated_responses", None)
|
||||
past_user_inputs = deprecated_kwargs.pop("past_user_inputs", None)
|
||||
if generated_responses is not None and past_user_inputs is None:
|
||||
raise ValueError("generated_responses cannot be passed without past_user_inputs!")
|
||||
if past_user_inputs is not None:
|
||||
legacy_messages = []
|
||||
if generated_responses is None:
|
||||
generated_responses = []
|
||||
# We structure it this way instead of using zip() because the lengths may differ by 1
|
||||
for i in range(max([len(past_user_inputs), len(generated_responses)])):
|
||||
if i < len(past_user_inputs):
|
||||
legacy_messages.append({"role": "user", "content": past_user_inputs[i]})
|
||||
if i < len(generated_responses):
|
||||
legacy_messages.append({"role": "assistant", "content": generated_responses[i]})
|
||||
messages = legacy_messages + messages
|
||||
|
||||
self.uuid = conversation_id
|
||||
self.messages = messages
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Conversation):
|
||||
return False
|
||||
return self.uuid == other.uuid or self.messages == other.messages
|
||||
|
||||
def add_message(self, message: Dict[str, str]):
|
||||
if not set(message.keys()) == {"role", "content"}:
|
||||
raise ValueError("Message should contain only 'role' and 'content' keys!")
|
||||
if message["role"] not in ("user", "assistant", "system"):
|
||||
raise ValueError("Only 'user', 'assistant' and 'system' roles are supported for now!")
|
||||
self.messages.append(message)
|
||||
|
||||
def add_user_input(self, text: str, overwrite: bool = False):
|
||||
"""
|
||||
Add a user input to the conversation for the next round. This is a legacy method that assumes that inputs must
|
||||
alternate user/assistant/user/assistant, and so will not add multiple user messages in succession. We recommend
|
||||
just using `add_message` with role "user" instead.
|
||||
"""
|
||||
if len(self) > 0 and self[-1]["role"] == "user":
|
||||
if overwrite:
|
||||
logger.warning(
|
||||
f'User input added while unprocessed input was existing: "{self[-1]["content"]}" was overwritten '
|
||||
f'with: "{text}".'
|
||||
)
|
||||
self[-1]["content"] = text
|
||||
else:
|
||||
logger.warning(
|
||||
f'User input added while unprocessed input was existing: "{self[-1]["content"]}" new input '
|
||||
f'ignored: "{text}". Set `overwrite` to True to overwrite unprocessed user input'
|
||||
)
|
||||
else:
|
||||
self.messages.append({"role": "user", "content": text})
|
||||
|
||||
def append_response(self, response: str):
|
||||
"""
|
||||
This is a legacy method. We recommend just using `add_message` with an appropriate role instead.
|
||||
"""
|
||||
self.messages.append({"role": "assistant", "content": response})
|
||||
|
||||
def mark_processed(self):
|
||||
"""
|
||||
This is a legacy method, as the Conversation no longer distinguishes between processed and unprocessed user
|
||||
input. We set a counter here to keep behaviour mostly backward-compatible, but in general you should just read
|
||||
the messages directly when writing new code.
|
||||
"""
|
||||
self._num_processed_user_inputs = len(self._user_messages)
|
||||
|
||||
def __iter__(self):
|
||||
for message in self.messages:
|
||||
yield message
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.messages[item]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.messages[key] = value
|
||||
|
||||
def __len__(self):
|
||||
return len(self.messages)
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
Generates a string representation of the conversation.
|
||||
|
||||
Returns:
|
||||
`str`:
|
||||
|
||||
Example:
|
||||
Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114 user: Going to the movies tonight - any suggestions?
|
||||
bot: The Big Lebowski
|
||||
"""
|
||||
output = f"Conversation id: {self.uuid}\n"
|
||||
for message in self.messages:
|
||||
output += f"{message['role']}: {message['content']}\n"
|
||||
return output
|
||||
|
||||
def iter_texts(self):
|
||||
# This is a legacy method for backwards compatibility. It is recommended to just directly access
|
||||
# conversation.messages instead.
|
||||
for message in self.messages:
|
||||
yield message["role"] == "user", message["content"]
|
||||
|
||||
@property
|
||||
def _user_messages(self):
|
||||
# This is a legacy property for backwards compatibility. It is recommended to just directly access
|
||||
# conversation.messages instead.
|
||||
return [message["content"] for message in self.messages if message["role"] == "user"]
|
||||
|
||||
@property
|
||||
def past_user_inputs(self):
|
||||
# This is a legacy property for backwards compatibility. It is recommended to just directly access
|
||||
# conversation.messages instead. The modern class does not care about which messages are "processed"
|
||||
# or not.
|
||||
if not self._user_messages:
|
||||
return []
|
||||
# In the past, the most recent user message had to be mark_processed() before being included
|
||||
# in past_user_messages. The class essentially had a single-message buffer, representing messages that
|
||||
# had not yet been replied to. This is no longer the case, but we mimic the behaviour in this property
|
||||
# for backward compatibility.
|
||||
if self.messages[-1]["role"] != "user" or self._num_processed_user_inputs == len(self._user_messages):
|
||||
return self._user_messages
|
||||
|
||||
return self._user_messages[:-1]
|
||||
|
||||
@property
|
||||
def generated_responses(self):
|
||||
# This is a legacy property for backwards compatibility. It is recommended to just directly access
|
||||
# conversation.messages instead.
|
||||
return [message["content"] for message in self.messages if message["role"] == "assistant"]
|
||||
|
||||
@property
|
||||
def new_user_input(self):
|
||||
# This is a legacy property for backwards compatibility. It is recommended to just directly access
|
||||
# conversation.messages instead.
|
||||
return self._user_messages[-1]
|
||||
|
||||
|
||||
@add_end_docstrings(
|
||||
build_pipeline_init_args(has_tokenizer=True),
|
||||
r"""
|
||||
min_length_for_response (`int`, *optional*, defaults to 32):
|
||||
The minimum length (in number of tokens) for a response.""",
|
||||
)
|
||||
class ConversationalPipeline(Pipeline):
|
||||
"""
|
||||
Multi-turn conversational pipeline.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline, Conversation
|
||||
# Any model with a chat template can be used in a ConversationalPipeline.
|
||||
|
||||
>>> chatbot = pipeline(model="facebook/blenderbot-400M-distill")
|
||||
>>> # Conversation objects initialized with a string will treat it as a user message
|
||||
>>> conversation = Conversation("I'm looking for a movie - what's your favourite one?")
|
||||
>>> conversation = chatbot(conversation)
|
||||
>>> conversation.messages[-1]["content"]
|
||||
"I don't really have a favorite movie, but I do like action movies. What about you?"
|
||||
|
||||
>>> conversation.add_message({"role": "user", "content": "That's interesting, why do you like action movies?"})
|
||||
>>> conversation = chatbot(conversation)
|
||||
>>> conversation.messages[-1]["content"]
|
||||
" I think it's just because they're so fast-paced and action-fantastic."
|
||||
```
|
||||
|
||||
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
|
||||
|
||||
This conversational pipeline can currently be loaded from [`pipeline`] using the following task identifier:
|
||||
`"conversational"`.
|
||||
|
||||
This pipeline can be used with any model that has a [chat
|
||||
template](https://huggingface.co/docs/transformers/chat_templating) set.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
warnings.warn(
|
||||
"`ConversationalPipeline` is now deprecated, and the functionality has been moved to the standard `text-generation` pipeline, which now accepts lists of message dicts as well as strings. This class will be removed in v4.42.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.tokenizer.pad_token_id is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
def _sanitize_parameters(self, min_length_for_response=None, clean_up_tokenization_spaces=None, **generate_kwargs):
|
||||
preprocess_params = {}
|
||||
forward_params = {}
|
||||
postprocess_params = {}
|
||||
|
||||
if min_length_for_response is not None:
|
||||
preprocess_params["min_length_for_response"] = min_length_for_response
|
||||
|
||||
if "max_length" in generate_kwargs:
|
||||
forward_params["max_length"] = generate_kwargs["max_length"]
|
||||
# self.max_length = generate_kwargs.get("max_length", self.model.config.max_length)
|
||||
if clean_up_tokenization_spaces is not None:
|
||||
postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces
|
||||
|
||||
if generate_kwargs:
|
||||
forward_params.update(generate_kwargs)
|
||||
return preprocess_params, forward_params, postprocess_params
|
||||
|
||||
def __call__(self, conversations: Union[List[Dict], Conversation, List[Conversation]], num_workers=0, **kwargs):
|
||||
r"""
|
||||
Generate responses for the conversation(s) given as inputs.
|
||||
|
||||
Args:
|
||||
conversations (a [`Conversation`] or a list of [`Conversation`]):
|
||||
Conversation to generate responses for. Inputs can also be passed as a list of dictionaries with `role`
|
||||
and `content` keys - in this case, they will be converted to `Conversation` objects automatically.
|
||||
Multiple conversations in either format may be passed as a list.
|
||||
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to clean up the potential extra spaces in the text output.
|
||||
generate_kwargs:
|
||||
Additional keyword arguments to pass along to the generate method of the model (see the generate method
|
||||
corresponding to your framework [here](./main_classes/text_generation)).
|
||||
|
||||
Returns:
|
||||
[`Conversation`] or a list of [`Conversation`]: Conversation(s) with updated generated responses for those
|
||||
containing a new user input.
|
||||
"""
|
||||
# XXX: num_workers==0 is required to be backward compatible
|
||||
# Otherwise the threads will require a Conversation copy.
|
||||
# This will definitely hinder performance on GPU, but has to be opted
|
||||
# in because of this BC change.
|
||||
if isinstance(conversations, list) and isinstance(conversations[0], dict):
|
||||
conversations = Conversation(conversations)
|
||||
elif isinstance(conversations, list) and isinstance(conversations[0], list):
|
||||
conversations = [Conversation(conv) for conv in conversations]
|
||||
outputs = super().__call__(conversations, num_workers=num_workers, **kwargs)
|
||||
if isinstance(outputs, list) and len(outputs) == 1:
|
||||
return outputs[0]
|
||||
return outputs
|
||||
|
||||
def preprocess(self, conversation: Conversation, min_length_for_response=32) -> Dict[str, Any]:
|
||||
input_ids = self.tokenizer.apply_chat_template(conversation, add_generation_prompt=True)
|
||||
|
||||
if self.framework == "pt":
|
||||
input_ids = torch.LongTensor([input_ids])
|
||||
elif self.framework == "tf":
|
||||
input_ids = tf.constant([input_ids])
|
||||
return {"input_ids": input_ids, "conversation": conversation}
|
||||
|
||||
def _forward(self, model_inputs, **generate_kwargs):
|
||||
n = model_inputs["input_ids"].shape[1]
|
||||
conversation = model_inputs.pop("conversation")
|
||||
if "max_length" not in generate_kwargs and "max_new_tokens" not in generate_kwargs:
|
||||
generate_kwargs["max_new_tokens"] = 256
|
||||
output_ids = self.model.generate(**model_inputs, **generate_kwargs)
|
||||
if self.model.config.is_encoder_decoder:
|
||||
start_position = 1
|
||||
else:
|
||||
start_position = n
|
||||
return {"output_ids": output_ids[:, start_position:], "conversation": conversation}
|
||||
|
||||
def postprocess(self, model_outputs, clean_up_tokenization_spaces=True):
|
||||
output_ids = model_outputs["output_ids"]
|
||||
answer = self.tokenizer.decode(
|
||||
output_ids[0],
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
)
|
||||
conversation = model_outputs["conversation"]
|
||||
conversation.add_message({"role": "assistant", "content": answer})
|
||||
return conversation
|
|
@@ -72,8 +72,6 @@ if TYPE_CHECKING:
     import tensorflow as tf
     if is_flax_available():
         import jax.numpy as jnp  # noqa: F401
-    from .pipelines.conversational import Conversation
-
 
 if is_tokenizers_available():
     from tokenizers import AddedToken
@@ -1684,7 +1682,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
 
     def apply_chat_template(
         self,
-        conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"],
+        conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
         chat_template: Optional[str] = None,
         add_generation_prompt: bool = False,
         tokenize: bool = True,

@@ -1703,7 +1701,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         to the default_chat_template specified at the class level.
 
         Args:
-            conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"]): A list of dicts
+            conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts
                 with "role" and "content" keys, representing the chat history so far.
             chat_template (str, *optional*): A Jinja template to use for this conversion. If
                 this is not passed, the model's default chat template will be used instead.
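The updated annotation keeps the batched form, `List[List[Dict[str, str]]]`. A small sketch of that batched usage, with an illustrative checkpoint:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")  # illustrative checkpoint

chats = [
    [{"role": "user", "content": "Hi there!"}],
    [{"role": "user", "content": "Tell me a joke about tokenizers."}],
]

# A list of chats is templated per chat; tokenize=False returns one prompt string per chat.
prompts = tokenizer.apply_chat_template(chats, tokenize=False, add_generation_prompt=True)
for prompt in prompts:
    print(prompt)
```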
@ -430,7 +430,6 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": BartForConditionalGeneration,
|
||||
"feature-extraction": BartModel,
|
||||
"fill-mask": BartForConditionalGeneration,
|
||||
"question-answering": BartForQuestionAnswering,
|
||||
|
@ -513,10 +512,6 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||
model.generate(input_ids, attention_mask=attention_mask)
|
||||
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
||||
|
||||
@unittest.skip("Does not support conversations.")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
||||
|
||||
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
|
||||
|
|
|
@ -198,7 +198,6 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
|
|||
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": TFBartForConditionalGeneration,
|
||||
"feature-extraction": TFBartModel,
|
||||
"summarization": TFBartForConditionalGeneration,
|
||||
"text-classification": TFBartForSequenceClassification,
|
||||
|
@ -343,10 +342,6 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
|
|||
# check that the output for the restored model is the same
|
||||
self.assert_outputs_same(restored_model_outputs, outputs)
|
||||
|
||||
@unittest.skip("Does not support conversations.")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
||||
|
||||
def _long_tensor(tok_lst):
|
||||
return tf.constant(tok_lst, dtype=tf.int32)
|
||||
|
|
|
@ -253,7 +253,6 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
|||
all_generative_model_classes = (BigBirdPegasusForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": BigBirdPegasusForConditionalGeneration,
|
||||
"feature-extraction": BigBirdPegasusModel,
|
||||
"question-answering": BigBirdPegasusForQuestionAnswering,
|
||||
"summarization": BigBirdPegasusForConditionalGeneration,
|
||||
|
|
|
@ -237,7 +237,6 @@ class BlenderbotModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||
all_generative_model_classes = (BlenderbotForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": BlenderbotForConditionalGeneration,
|
||||
"feature-extraction": BlenderbotModel,
|
||||
"summarization": BlenderbotForConditionalGeneration,
|
||||
"text-generation": BlenderbotForCausalLM,
|
||||
|
|
|
@ -183,7 +183,6 @@ class TFBlenderbotModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Te
|
|||
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": TFBlenderbotForConditionalGeneration,
|
||||
"feature-extraction": TFBlenderbotModel,
|
||||
"summarization": TFBlenderbotForConditionalGeneration,
|
||||
"text2text-generation": TFBlenderbotForConditionalGeneration,
|
||||
|
|
|
@ -228,7 +228,6 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
|||
all_generative_model_classes = (BlenderbotSmallForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": BlenderbotSmallForConditionalGeneration,
|
||||
"feature-extraction": BlenderbotSmallModel,
|
||||
"summarization": BlenderbotSmallForConditionalGeneration,
|
||||
"text-generation": BlenderbotSmallForCausalLM,
|
||||
|
@ -247,7 +246,7 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
|||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
):
|
||||
return pipeline_test_casse_name in ("TextGenerationPipelineTests", "ConversationalPipelineTests")
|
||||
return pipeline_test_casse_name == "TextGenerationPipelineTests"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BlenderbotSmallModelTester(self)
|
||||
|
|
|
@ -323,7 +323,7 @@ class FlaxBlenderbotSmallModelTest(FlaxModelTesterMixin, unittest.TestCase, Flax
|
|||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
):
|
||||
return pipeline_test_casse_name in ("TextGenerationPipelineTests", "ConversationalPipelineTests")
|
||||
return pipeline_test_casse_name == "TextGenerationPipelineTests"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FlaxBlenderbotSmallModelTester(self)
|
||||
|
|
|
@ -185,7 +185,6 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, PipelineTesterMixin, unitte
|
|||
all_generative_model_classes = (TFBlenderbotSmallForConditionalGeneration,) if is_tf_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": TFBlenderbotSmallForConditionalGeneration,
|
||||
"feature-extraction": TFBlenderbotSmallModel,
|
||||
"summarization": TFBlenderbotSmallForConditionalGeneration,
|
||||
"text2text-generation": TFBlenderbotSmallForConditionalGeneration,
|
||||
|
@ -201,7 +200,7 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, PipelineTesterMixin, unitte
|
|||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
):
|
||||
return pipeline_test_casse_name in ("TextGenerationPipelineTests", "ConversationalPipelineTests")
|
||||
return pipeline_test_casse_name == "TextGenerationPipelineTests"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBlenderbotSmallModelTester(self)
|
||||
|
|
|
@ -166,7 +166,6 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||
all_generative_model_classes = (FSMTForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": FSMTForConditionalGeneration,
|
||||
"feature-extraction": FSMTModel,
|
||||
"summarization": FSMTForConditionalGeneration,
|
||||
"text2text-generation": FSMTForConditionalGeneration,
|
||||
|
|
|
@ -284,7 +284,6 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||
all_generative_model_classes = (LEDForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": LEDForConditionalGeneration,
|
||||
"feature-extraction": LEDModel,
|
||||
"question-answering": LEDForQuestionAnswering,
|
||||
"summarization": LEDForConditionalGeneration,
|
||||
|
|
|
@ -197,7 +197,6 @@ class TFLEDModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
|||
all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": TFLEDForConditionalGeneration,
|
||||
"feature-extraction": TFLEDModel,
|
||||
"summarization": TFLEDForConditionalGeneration,
|
||||
"text2text-generation": TFLEDForConditionalGeneration,
|
||||
|
|
|
@ -504,7 +504,6 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||
all_generative_model_classes = (LongT5ForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": LongT5ForConditionalGeneration,
|
||||
"feature-extraction": LongT5Model,
|
||||
"summarization": LongT5ForConditionalGeneration,
|
||||
"text2text-generation": LongT5ForConditionalGeneration,
|
||||
|
|
|
@ -243,7 +243,6 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||
all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": M2M100ForConditionalGeneration,
|
||||
"feature-extraction": M2M100Model,
|
||||
"summarization": M2M100ForConditionalGeneration,
|
||||
"text2text-generation": M2M100ForConditionalGeneration,
|
||||
|
|
|
@ -311,10 +311,6 @@ class FlaxMarianModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGeneratio
|
|||
outputs = model(input_ids)
|
||||
self.assertIsNotNone(outputs)
|
||||
|
||||
@unittest.skip("Skipping for now, to fix @ArthurZ or @ydshieh")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_flax
|
||||
@require_sentencepiece
|
||||
|
|
|
@ -248,7 +248,6 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||
all_generative_model_classes = (MarianMTModel,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": MarianMTModel,
|
||||
"feature-extraction": MarianModel,
|
||||
"summarization": MarianMTModel,
|
||||
"text-generation": MarianForCausalLM,
|
||||
|
@ -350,10 +349,6 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||
def test_tie_word_embeddings_decoder(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Skipping for now, to fix @ArthurZ or @ydshieh")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
|
|
|
@ -184,7 +184,6 @@ class TFMarianModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||
all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": TFMarianMTModel,
|
||||
"feature-extraction": TFMarianModel,
|
||||
"summarization": TFMarianMTModel,
|
||||
"text2text-generation": TFMarianMTModel,
|
||||
|
@ -208,10 +207,6 @@ class TFMarianModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||
|
||||
@unittest.skip("Skipping for now, to fix @ArthurZ or @ydshieh")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_tf
|
||||
class AbstractMarianIntegrationTest(unittest.TestCase):
|
||||
|
|
|
@ -240,7 +240,6 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||
all_generative_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": MBartForConditionalGeneration,
|
||||
"feature-extraction": MBartModel,
|
||||
"fill-mask": MBartForConditionalGeneration,
|
||||
"question-answering": MBartForQuestionAnswering,
|
||||
|
|
|
@ -161,7 +161,6 @@ class TFMBartModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
|||
all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": TFMBartForConditionalGeneration,
|
||||
"feature-extraction": TFMBartModel,
|
||||
"summarization": TFMBartForConditionalGeneration,
|
||||
"text2text-generation": TFMBartForConditionalGeneration,
|
||||
|
|
|
@ -555,7 +555,6 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||
all_generative_model_classes = (MT5ForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": MT5ForConditionalGeneration,
|
||||
"feature-extraction": MT5Model,
|
||||
"question-answering": MT5ForQuestionAnswering,
|
||||
"summarization": MT5ForConditionalGeneration,
|
||||
|
@ -886,10 +885,6 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
||||
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
|
||||
|
||||
@unittest.skip("Does not support conversations.")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
||||
|
||||
# Copied from tests.models.t5.test_modeling_t5.T5EncoderOnlyModelTester with T5->MT5
|
||||
class MT5EncoderOnlyModelTester:
|
||||
|
|
|
@ -421,7 +421,6 @@ class MvpModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||
all_generative_model_classes = (MvpForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": MvpForConditionalGeneration,
|
||||
"feature-extraction": MvpModel,
|
||||
"fill-mask": MvpForConditionalGeneration,
|
||||
"question-answering": MvpForQuestionAnswering,
|
||||
|
|
|
@ -250,7 +250,6 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||
all_generative_model_classes = (NllbMoeForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": NllbMoeForConditionalGeneration,
|
||||
"feature-extraction": NllbMoeModel,
|
||||
"summarization": NllbMoeForConditionalGeneration,
|
||||
"text2text-generation": NllbMoeForConditionalGeneration,
|
||||
|
|
|
@ -246,7 +246,6 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||
all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": PegasusForConditionalGeneration,
|
||||
"feature-extraction": PegasusModel,
|
||||
"summarization": PegasusForConditionalGeneration,
|
||||
"text-generation": PegasusForCausalLM,
|
||||
|
|
|
@ -182,7 +182,6 @@ class TFPegasusModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
|||
all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": TFPegasusForConditionalGeneration,
|
||||
"feature-extraction": TFPegasusModel,
|
||||
"summarization": TFPegasusForConditionalGeneration,
|
||||
"text2text-generation": TFPegasusForConditionalGeneration,
|
||||
|
|
|
@ -206,7 +206,6 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||
all_generative_model_classes = (PegasusXForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": PegasusXForConditionalGeneration,
|
||||
"feature-extraction": PegasusXModel,
|
||||
"summarization": PegasusXForConditionalGeneration,
|
||||
"text2text-generation": PegasusXForConditionalGeneration,
|
||||
|
|
|
@ -227,7 +227,6 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||
all_generative_model_classes = (PLBartForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": PLBartForConditionalGeneration,
|
||||
"feature-extraction": PLBartModel,
|
||||
"summarization": PLBartForConditionalGeneration,
|
||||
"text-classification": PLBartForSequenceClassification,
|
||||
|
|
|
@ -891,7 +891,6 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||
all_generative_model_classes = (ProphetNetForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": ProphetNetForConditionalGeneration,
|
||||
"feature-extraction": ProphetNetModel,
|
||||
"summarization": ProphetNetForConditionalGeneration,
|
||||
"text-generation": ProphetNetForCausalLM,
|
||||
|
|
|
@ -645,7 +645,6 @@ class SeamlessM4TModelWithTextInputTest(
|
|||
pipeline_model_mapping = (
|
||||
{
|
||||
"automatic-speech-recognition": SeamlessM4TForSpeechToText,
|
||||
"conversational": SeamlessM4TForTextToText,
|
||||
"feature-extraction": SeamlessM4TModel,
|
||||
"summarization": SeamlessM4TForTextToText,
|
||||
"text-to-audio": SeamlessM4TForTextToSpeech,
|
||||
|
|
|
@ -559,7 +559,6 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
|
|||
all_generative_model_classes = (SwitchTransformersForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": SwitchTransformersForConditionalGeneration,
|
||||
"feature-extraction": SwitchTransformersModel,
|
||||
"summarization": SwitchTransformersForConditionalGeneration,
|
||||
"text2text-generation": SwitchTransformersForConditionalGeneration,
|
||||
|
|
|
@ -558,7 +558,6 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": T5ForConditionalGeneration,
|
||||
"feature-extraction": T5Model,
|
||||
"question-answering": T5ForQuestionAnswering,
|
||||
"summarization": T5ForConditionalGeneration,
|
||||
|
@ -889,10 +888,6 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
||||
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
|
||||
|
||||
@unittest.skip("Does not support conversations.")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
||||
|
||||
class T5EncoderOnlyModelTester:
|
||||
def __init__(
|
||||
|
|
|
@ -248,7 +248,6 @@ class TFT5ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": TFT5ForConditionalGeneration,
|
||||
"feature-extraction": TFT5Model,
|
||||
"summarization": TFT5ForConditionalGeneration,
|
||||
"text2text-generation": TFT5ForConditionalGeneration,
|
||||
|
@ -314,10 +313,6 @@ class TFT5ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||
def test_keras_save_load(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Does not support conversations.")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
||||
|
||||
class TFT5EncoderOnlyModelTester:
|
||||
def __init__(
|
||||
|
@ -611,10 +606,6 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
|
|||
expected_output_string = ["Ich liebe es so sehr!", "die Transformatoren sind wirklich erstaunlich"]
|
||||
self.assertListEqual(expected_output_string, output_strings)
|
||||
|
||||
@unittest.skip("Does not support conversations.")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_tf
|
||||
@require_sentencepiece
|
||||
|
|
|
@ -297,7 +297,6 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||
all_generative_model_classes = (UMT5ForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"conversational": UMT5ForConditionalGeneration,
|
||||
"feature-extraction": UMT5Model,
|
||||
"question-answering": UMT5ForQuestionAnswering,
|
||||
"summarization": UMT5ForConditionalGeneration,
|
||||
|
|
|
@ -1,439 +0,0 @@
|
|||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
BlenderbotSmallForConditionalGeneration,
|
||||
BlenderbotSmallTokenizer,
|
||||
Conversation,
|
||||
ConversationalPipeline,
|
||||
TFAutoModelForCausalLM,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
backend_empty_cache,
|
||||
is_pipeline_test,
|
||||
is_torch_available,
|
||||
require_tf,
|
||||
require_torch,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from .test_pipelines_common import ANY
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class ConversationalPipelineTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||
gc.collect()
|
||||
if is_torch_available():
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
model_mapping = dict(
|
||||
list(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items())
|
||||
if MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
else [] + list(MODEL_FOR_CAUSAL_LM_MAPPING.items())
|
||||
if MODEL_FOR_CAUSAL_LM_MAPPING
|
||||
else []
|
||||
)
|
||||
tf_model_mapping = dict(
|
||||
list(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items())
|
||||
if TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
else [] + list(TF_MODEL_FOR_CAUSAL_LM_MAPPING.items())
|
||||
if TF_MODEL_FOR_CAUSAL_LM_MAPPING
|
||||
else []
|
||||
)
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor):
|
||||
conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)
|
||||
return conversation_agent, [Conversation("Hi there!")]
|
||||
|
||||
def run_pipeline_test(self, conversation_agent, _):
|
||||
# Simple
|
||||
outputs = conversation_agent(Conversation("Hi there!"), max_new_tokens=5)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
|
||||
)
|
||||
|
||||
# Single list
|
||||
outputs = conversation_agent([Conversation("Hi there!")], max_new_tokens=5)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
|
||||
)
|
||||
|
||||
# Batch
|
||||
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
||||
conversation_2 = Conversation("What's the last book you have read?")
|
||||
self.assertEqual(len(conversation_1), 1)
|
||||
self.assertEqual(len(conversation_2), 1)
|
||||
|
||||
outputs = conversation_agent([conversation_1, conversation_2], max_new_tokens=5)
|
||||
self.assertEqual(outputs, [conversation_1, conversation_2])
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
Conversation(
|
||||
[
|
||||
{"role": "user", "content": "Going to the movies tonight - any suggestions?"},
|
||||
{"role": "assistant", "content": ANY(str)},
|
||||
],
|
||||
),
|
||||
Conversation(
|
||||
[
|
||||
{"role": "user", "content": "What's the last book you have read?"},
|
||||
{"role": "assistant", "content": ANY(str)},
|
||||
]
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# One conversation with history
|
||||
conversation_2.add_message({"role": "user", "content": "Why do you recommend it?"})
|
||||
outputs = conversation_agent(conversation_2, max_new_tokens=5)
|
||||
self.assertEqual(outputs, conversation_2)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
Conversation(
|
||||
[
|
||||
{"role": "user", "content": "What's the last book you have read?"},
|
||||
{"role": "assistant", "content": ANY(str)},
|
||||
{"role": "user", "content": "Why do you recommend it?"},
|
||||
{"role": "assistant", "content": ANY(str)},
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_integration_torch_conversation(self):
|
||||
# When
|
||||
conversation_agent = pipeline(task="conversational", device=torch_device)
|
||||
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
||||
conversation_2 = Conversation("What's the last book you have read?")
|
||||
# Then
|
||||
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
||||
self.assertEqual(len(conversation_2.past_user_inputs), 0)
|
||||
# When
|
||||
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
||||
# Then
|
||||
self.assertEqual(result, [conversation_1, conversation_2])
|
||||
self.assertEqual(len(result[0].past_user_inputs), 1)
|
||||
self.assertEqual(len(result[1].past_user_inputs), 1)
|
||||
self.assertEqual(len(result[0].generated_responses), 1)
|
||||
self.assertEqual(len(result[1].generated_responses), 1)
|
||||
self.assertEqual(result[0].past_user_inputs[0], "Going to the movies tonight - any suggestions?")
|
||||
self.assertEqual(result[0].generated_responses[0], "The Big Lebowski")
|
||||
self.assertEqual(result[1].past_user_inputs[0], "What's the last book you have read?")
|
||||
self.assertEqual(result[1].generated_responses[0], "The Last Question")
|
||||
# When
|
||||
conversation_2.add_user_input("Why do you recommend it?")
|
||||
result = conversation_agent(conversation_2, do_sample=False, max_length=1000)
|
||||
# Then
|
||||
self.assertEqual(result, conversation_2)
|
||||
self.assertEqual(len(result.past_user_inputs), 2)
|
||||
self.assertEqual(len(result.generated_responses), 2)
|
||||
self.assertEqual(result.past_user_inputs[1], "Why do you recommend it?")
|
||||
self.assertEqual(result.generated_responses[1], "It's a good book.")
|
||||
|
||||
@require_torch
@slow
def test_integration_torch_conversation_truncated_history(self):
# When
conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=torch_device)
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
# Then
self.assertEqual(len(conversation_1.past_user_inputs), 0)
# When
result = conversation_agent(conversation_1, do_sample=False, max_length=36)
# Then
self.assertEqual(result, conversation_1)
self.assertEqual(len(result.past_user_inputs), 1)
self.assertEqual(len(result.generated_responses), 1)
self.assertEqual(result.past_user_inputs[0], "Going to the movies tonight - any suggestions?")
self.assertEqual(result.generated_responses[0], "The Big Lebowski")
# When
conversation_1.add_user_input("Is it an action movie?")
result = conversation_agent(conversation_1, do_sample=False, max_length=36)
# Then
self.assertEqual(result, conversation_1)
self.assertEqual(len(result.past_user_inputs), 2)
self.assertEqual(len(result.generated_responses), 2)
self.assertEqual(result.past_user_inputs[1], "Is it an action movie?")
self.assertEqual(result.generated_responses[1], "It's a comedy.")

@require_torch
def test_small_model_pt(self):
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)
conversation = Conversation("hello")
output = conversation_agent(conversation)
self.assertEqual(output, Conversation(past_user_inputs=["hello"], generated_responses=["Hi"]))

@require_tf
def test_small_model_tf(self):
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model = TFAutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)
conversation = Conversation("hello")
output = conversation_agent(conversation)
self.assertEqual(output, Conversation(past_user_inputs=["hello"], generated_responses=["Hi"]))

@require_torch
@slow
def test_integration_torch_conversation_dialogpt_input_ids(self):
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)

conversation_1 = Conversation("hello")
inputs = conversation_agent.preprocess(conversation_1)
self.assertEqual(inputs["input_ids"].tolist(), [[31373, 50256]])

conversation_2 = Conversation("how are you ?", past_user_inputs=["hello"], generated_responses=["Hi there!"])
inputs = conversation_agent.preprocess(conversation_2)
self.assertEqual(
inputs["input_ids"].tolist(), [[31373, 50256, 17250, 612, 0, 50256, 4919, 389, 345, 5633, 50256]]
)

@unittest.skip("Model is currently gated")
@require_torch
@slow
def test_integration_torch_conversation_llama2_input_ids(self):
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_default_system_prompt=True)

conversation = Conversation(
"What is so great about #1?",
past_user_inputs=["I am going to Paris, what should I see?"],
generated_responses=[
"""\
Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:
1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.
2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.
3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.
These are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world."""
],
)
inputs = tokenizer._build_conversation_input_ids(conversation)
EXPECTED_INPUTS_IDS = [ 1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 29892, 3390, 1319, 322, 15993, 20255, 29889, 29849, 1234, 408, 1371, 3730, 408, 1950, 29892, 1550, 1641, 9109, 29889, 29871, 3575, 6089, 881, 451, 3160, 738, 10311, 1319, 29892, 443, 621, 936, 29892, 11021, 391, 29892, 7916, 391, 29892, 304, 27375, 29892, 18215, 29892, 470, 27302, 2793, 29889, 3529, 9801, 393, 596, 20890, 526, 5374, 635, 443, 5365, 1463, 322, 6374, 297, 5469, 29889, 13, 13, 3644, 263, 1139, 947, 451, 1207, 738, 4060, 29892, 470, 338, 451, 2114, 1474, 16165, 261, 296, 29892, 5649, 2020, 2012, 310, 22862, 1554, 451, 1959, 29889, 960, 366, 1016, 29915, 29873, 1073, 278, 1234, 304, 263, 1139, 29892, 3113, 1016, 29915, 29873, 6232, 2089, 2472, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 29902, 626, 2675, 304, 3681, 29892, 825, 881, 306, 1074, 29973, 518, 29914, 25580, 29962, 3681, 29892, 278, 7483, 310, 3444, 29892, 338, 2998, 363, 967, 380, 27389, 11258, 29892, 1616, 19133, 29879, 29892, 15839, 2982, 22848, 29892, 322, 6017, 7716, 25005, 29889, 2266, 526, 777, 310, 278, 2246, 19650, 1953, 304, 1074, 297, 3681, 29901, 13, 13, 29896, 29889, 450, 382, 2593, 295, 23615, 29901, 450, 9849, 293, 382, 2593, 295, 23615, 338, 697, 310, 278, 1556, 5936, 13902, 2982, 22848, 297, 278, 3186, 322, 16688, 2078, 271, 400, 5086, 8386, 310, 278, 4272, 29889, 13, 29906, 29889, 450, 4562, 12675, 6838, 29901, 450, 4562, 12675, 338, 697, 310, 278, 3186, 29915, 29879, 10150, 322, 1556, 13834, 19133, 29879, 29892, 27261, 385, 21210, 573, 4333, 310, 1616, 322, 24238, 29879, 29892, 3704, 278, 2598, 29874, 29420, 29889, 13, 29941, 29889, 24337, 29899, 29928, 420, 315, 21471, 29901, 910, 9560, 274, 21471, 338, 697, 310, 278, 1556, 13834, 2982, 22848, 297, 3681, 322, 338, 2998, 363, 967, 22883, 293, 11258, 322, 380, 27389, 380, 7114, 12917, 5417, 29889, 13, 13, 1349, 968, 526, 925, 263, 2846, 310, 278, 1784, 19650, 1953, 393, 3681, 756, 304, 5957, 29889, 2973, 577, 1568, 304, 1074, 322, 437, 29892, 372, 29915, 29879, 694, 4997, 393, 3681, 338, 697, 310, 278, 1556, 5972, 6282, 391, 15422, 800, 297, 278, 3186, 29889, 29871, 2, 1, 518, 25580, 29962, 1724, 338, 577, 2107, 1048, 396, 29896, 29973, 518, 29914, 25580, 29962] # fmt: skip
self.assertEqual(inputs, EXPECTED_INPUTS_IDS)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)
EXPECTED_TEXT = "what topic you want to focus on and create content around it. This will help you stand out from other creators and attract a specific audience.\n\nStep 2: Set Up Your Channel\nCreate your YouTube account and customize your channel with your branding and logo. Make sure your channel name and profile picture are consistent with your niche.\n\nStep 3: Plan Your Content\nDevelop a content strategy that includes the type of content you want to create, how often you will post, and when you will post. Consider creating a content calendar to help you stay organized.\n\nStep 4: Invest in Quality Equipment\nInvest in good quality camera and microphone equipment to ensure your videos look and sound professional. You don't need to break the bank, but investing in good equipment will make a big difference in the quality of your videos.\n\nStep 5: Optimize Your Videos for Search\nUse keywords in your video titles, descriptions, and tags to help people find your videos when they search for topics related to your niche"
conversation = Conversation(
"<<SYS>>\n Only answer with emojis, and charades\n<</SYS>>\n\nHow can I build a house in 10 steps?"
)
result = conversation_agent(conversation)
self.assertEqual(result.generated_responses[-1], EXPECTED_TEXT)

@require_torch
@slow
def test_integration_torch_conversation_blenderbot_400M_input_ids(self):
tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")
conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)

# test1
conversation_1 = Conversation("hello")
inputs = conversation_agent.preprocess(conversation_1)
self.assertEqual(inputs["input_ids"].tolist(), [[1710, 86, 2]])

# test2
conversation_1 = Conversation(
"I like lasagne.",
past_user_inputs=["hello"],
generated_responses=[
" Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie."
],
)
inputs = conversation_agent.preprocess(conversation_1)
self.assertEqual(
inputs["input_ids"].tolist(),
[
# This should be compared with the same conversation on ParlAI `safe_interactive` demo.
[
1710, # hello
86,
228, # Double space
228,
946,
304,
398,
6881,
558,
964,
38,
452,
315,
265,
6252,
452,
322,
968,
6884,
3146,
278,
306,
265,
617,
87,
388,
75,
341,
286,
521,
21,
228, # Double space
228,
281, # I like lasagne.
398,
6881,
558,
964,
21,
2, # EOS
],
],
)

@require_torch
@slow
def test_integration_torch_conversation_blenderbot_400M(self):
tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")
conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)

conversation_1 = Conversation("hello")
result = conversation_agent(
conversation_1,
)
self.assertEqual(
result.generated_responses[0],
# ParlAI implementation output, we have a different one, but it's our
# second best, you can check by using num_return_sequences=10
# " Hello! How are you? I'm just getting ready to go to work, how about you?",
" Hello! How are you doing today? I just got back from a walk with my dog.",
)

conversation_1 = Conversation("Lasagne hello")
result = conversation_agent(conversation_1, encoder_no_repeat_ngram_size=3)
self.assertEqual(
result.generated_responses[0],
" Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie.",
)

conversation_1 = Conversation(
"Lasagne hello Lasagne is my favorite Italian dish. Do you like lasagne? I like lasagne."
)
result = conversation_agent(
conversation_1,
encoder_no_repeat_ngram_size=3,
)
self.assertEqual(
result.generated_responses[0],
" Me too. I like how it can be topped with vegetables, meats, and condiments.",
)

@require_torch
@slow
def test_integration_torch_conversation_encoder_decoder(self):
# When
tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot_small-90M")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot_small-90M")
conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer, device=torch_device)

conversation_1 = Conversation("My name is Sarah and I live in London")
conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ")
# Then
self.assertEqual(len(conversation_1.past_user_inputs), 0)
self.assertEqual(len(conversation_2.past_user_inputs), 0)
# When
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
# Then
self.assertEqual(result, [conversation_1, conversation_2])
self.assertEqual(len(result[0].past_user_inputs), 1)
self.assertEqual(len(result[1].past_user_inputs), 1)
self.assertEqual(len(result[0].generated_responses), 1)
self.assertEqual(len(result[1].generated_responses), 1)
self.assertEqual(result[0].past_user_inputs[0], "My name is Sarah and I live in London")
self.assertEqual(
result[0].generated_responses[0],
"hi sarah, i live in london as well. do you have any plans for the weekend?",
)
self.assertEqual(
result[1].past_user_inputs[0], "Going to the movies tonight, What movie would you recommend? "
)
self.assertEqual(
result[1].generated_responses[0], "i don't know... i'm not really sure. what movie are you going to see?"
)
# When
conversation_1.add_user_input("Not yet, what about you?")
conversation_2.add_user_input("What's your name?")
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
# Then
self.assertEqual(result, [conversation_1, conversation_2])
self.assertEqual(len(result[0].past_user_inputs), 2)
self.assertEqual(len(result[1].past_user_inputs), 2)
self.assertEqual(len(result[0].generated_responses), 2)
self.assertEqual(len(result[1].generated_responses), 2)
self.assertEqual(result[0].past_user_inputs[1], "Not yet, what about you?")
self.assertEqual(result[0].generated_responses[1], "i don't have any plans yet. i'm not sure what to do yet.")
self.assertEqual(result[1].past_user_inputs[1], "What's your name?")
self.assertEqual(result[1].generated_responses[1], "i don't have a name, but i'm going to see a horror movie.")

@require_torch
@slow
def test_from_pipeline_conversation(self):
model_id = "facebook/blenderbot_small-90M"

# from model id
conversation_agent_from_model_id = pipeline("conversational", model=model_id, tokenizer=model_id)

# from model object
model = BlenderbotSmallForConditionalGeneration.from_pretrained(model_id)
tokenizer = BlenderbotSmallTokenizer.from_pretrained(model_id)
conversation_agent_from_model = pipeline("conversational", model=model, tokenizer=tokenizer)

conversation = Conversation("My name is Sarah and I live in London")
conversation_copy = Conversation("My name is Sarah and I live in London")

result_model_id = conversation_agent_from_model_id([conversation])
result_model = conversation_agent_from_model([conversation_copy])

# check for equality
self.assertEqual(
result_model_id.generated_responses[0],
"hi sarah, i live in london as well. do you have any plans for the weekend?",
)
self.assertEqual(
result_model_id.generated_responses[0],
result_model.generated_responses[0],
)

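For context only, and not part of this diff: once the conversational task is gone, a comparable smoke test could be routed through the text-generation pipeline's chat-formatted (list-of-dicts) input instead. The sketch below is an assumption rather than code from this branch; the checkpoint name, the availability of a chat template on it, and the exact shape of the returned `generated_text` are all illustrative.

```python
# Hypothetical sketch: a minimal chat smoke test against the text-generation pipeline.
# Assumes the checkpoint has a chat template and that chat input returns the full
# message list under "generated_text", ending with the new assistant turn.
from transformers import pipeline


def test_small_chat_model_smoke():
    pipe = pipeline("text-generation", model="microsoft/DialoGPT-small")
    chat = [{"role": "user", "content": "hello"}]
    out = pipe(chat, max_new_tokens=5)
    reply = out[0]["generated_text"][-1]  # last message is expected to be the assistant reply
    assert reply["role"] == "assistant"
    assert isinstance(reply["content"], str)
```
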
@ -33,7 +33,6 @@ from transformers.utils import direct_transformers_import, logging
from .pipelines.test_pipelines_audio_classification import AudioClassificationPipelineTests
from .pipelines.test_pipelines_automatic_speech_recognition import AutomaticSpeechRecognitionPipelineTests
from .pipelines.test_pipelines_conversational import ConversationalPipelineTests
from .pipelines.test_pipelines_depth_estimation import DepthEstimationPipelineTests
from .pipelines.test_pipelines_document_question_answering import DocumentQuestionAnsweringPipelineTests
from .pipelines.test_pipelines_feature_extraction import FeatureExtractionPipelineTests

@ -65,7 +64,6 @@ from .pipelines.test_pipelines_zero_shot_object_detection import ZeroShotObjectD
pipeline_test_mapping = {
"audio-classification": {"test": AudioClassificationPipelineTests},
"automatic-speech-recognition": {"test": AutomaticSpeechRecognitionPipelineTests},
"conversational": {"test": ConversationalPipelineTests},
"depth-estimation": {"test": DepthEstimationPipelineTests},
"document-question-answering": {"test": DocumentQuestionAnsweringPipelineTests},
"feature-extraction": {"test": FeatureExtractionPipelineTests},

@ -314,10 +312,6 @@ class PipelineTesterMixin:
yield copy.deepcopy(random.choice(examples))

out = []
if task == "conversational":
for item in pipeline(data(10), batch_size=4, max_new_tokens=5):
out.append(item)
else:
for item in pipeline(data(10), batch_size=4):
out.append(item)
self.assertEqual(len(out), 10)

@ -332,10 +326,6 @@ class PipelineTesterMixin:
def test_pipeline_automatic_speech_recognition(self):
self.run_task_tests(task="automatic-speech-recognition")

@is_pipeline_test
def test_pipeline_conversational(self):
self.run_task_tests(task="conversational")

@is_pipeline_test
@require_vision
@require_timm

@ -128,7 +128,6 @@ OBJECTS_TO_IGNORE = [
"ConvBertTokenizerFast",
"ConvNextConfig",
"ConvNextV2Config",
"ConversationalPipeline",
"CpmAntTokenizer",
"CvtConfig",
"CvtModel",

@ -918,7 +918,6 @@ src/transformers/pipelines/audio_classification.py
src/transformers/pipelines/audio_utils.py
src/transformers/pipelines/automatic_speech_recognition.py
src/transformers/pipelines/base.py
src/transformers/pipelines/conversational.py
src/transformers/pipelines/depth_estimation.py
src/transformers/pipelines/document_question_answering.py
src/transformers/pipelines/feature_extraction.py