Making it possible to create a full Conversation directly (#9434)

* Cleaning up conversation tests.
* Adding tests that don't require downloading models; a conversation can now be fully created from static state.
* Making tests non-flaky (by fixing the generation length).
* Bumping the isort version.
* Doc cleanup.
* Removing a test unused in this PR.
* Torch import guard for TF.
* Adding a missing torch guard.
* Fixing a small mistake in the docs.
* Actually using the `_history` and `_index` cache, removing a dead `enumerate`, and improving the warning message.
* Update src/transformers/pipelines/conversational.py (applying review suggestions, three commits).
* Adding comments and cleaner code to address the history copy.
* Improving the pipeline name in tests.
* Changing the tokenizer to a real one (still created at runtime with no external dependency).
* Simplifying DummyTok, reversing changes on tokenization.
* Removing DummyTok.

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent 4fbcf8ea49
commit 02e05fb0a5
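The headline change, sketched before the diff: a Conversation can now be rebuilt from static state instead of being replayed turn by turn through the pipeline. A minimal usage sketch based on the new constructor below (the response string is made up for illustration)::

    from transformers import Conversation

    # Recreate past turns from equal-length lists of strings; the first
    # positional argument stays the new, not-yet-answered user input.
    conversation = Conversation(
        "Why do you recommend it?",
        past_user_inputs=["What's the last book you have read?"],
        generated_responses=["The Last Question"],
    )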
src/transformers/pipelines/conversational.py

@@ -33,6 +33,14 @@ class Conversation:
         conversation_id (:obj:`uuid.UUID`, `optional`):
             Unique identifier for the conversation. If not provided, a random UUID4 id will be assigned to the
             conversation.
+        past_user_inputs (:obj:`List[str]`, `optional`):
+            Eventual past history of the conversation of the user. You don't need to pass it manually if you use the
+            pipeline interactively but if you want to recreate history you need to set both :obj:`past_user_inputs` and
+            :obj:`generated_responses` with equal length lists of strings
+        generated_responses (:obj:`List[str]`, `optional`):
+            Eventual past history of the conversation of the model. You don't need to pass it manually if you use the
+            pipeline interactively but if you want to recreate history you need to set both :obj:`past_user_inputs` and
+            :obj:`generated_responses` with equal length lists of strings

     Usage::

@@ -47,14 +55,33 @@ class Conversation:
         conversation.add_user_input("Is it good?")
     """

-    def __init__(self, text: str = None, conversation_id: uuid.UUID = None):
+    def __init__(
+        self, text: str = None, conversation_id: uuid.UUID = None, past_user_inputs=None, generated_responses=None
+    ):
         if not conversation_id:
             conversation_id = uuid.uuid4()
+        if past_user_inputs is None:
+            past_user_inputs = []
+        if generated_responses is None:
+            generated_responses = []
+
         self.uuid: uuid.UUID = conversation_id
-        self.past_user_inputs: List[str] = []
-        self.generated_responses: List[str] = []
-        self.history: List[int] = []
+        self.past_user_inputs: List[str] = past_user_inputs
+        self.generated_responses: List[str] = generated_responses
         self.new_user_input: Optional[str] = text
+        self._index: int = 0
+        self._history: List[int] = []
+
+    def __eq__(self, other):
+        if not isinstance(other, Conversation):
+            return False
+        if self.uuid == other.uuid:
+            return True
+        return (
+            self.new_user_input == other.new_user_input
+            and self.past_user_inputs == other.past_user_inputs
+            and self.generated_responses == other.generated_responses
+        )

     def add_user_input(self, text: str, overwrite: bool = False):
         """
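A consequence of the new `__eq__`, relied on by the tests further down: two conversations with different UUIDs compare equal whenever their visible state matches. A small sketch::

    from transformers import Conversation

    a = Conversation(None, past_user_inputs=["Hi"], generated_responses=["Hello"])
    b = Conversation(None, past_user_inputs=["Hi"], generated_responses=["Hello"])
    assert a == b  # same content wins even though the uuid4 ids differ
    assert a != Conversation("Hi")  # differing new_user_input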
@@ -100,16 +127,6 @@ class Conversation:
         """
         self.generated_responses.append(response)

-    def set_history(self, history: List[int]):
-        """
-        Updates the value of the history of the conversation. The history is represented by a list of :obj:`token_ids`.
-        The history is used by the model to generate responses based on the previous conversation turns.
-
-        Args:
-            history (:obj:`List[int]`): History of tokens provided and generated for this conversation.
-        """
-        self.history = history
-
     def __repr__(self):
         """
         Generates a string representation of the conversation.
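Migration note for the removal above: callers that previously seeded token-level history via `set_history` should now seed string-level state through the constructor and let the pipeline tokenize lazily. A hedged before/after sketch (the strings are illustrative)::

    # Before (removed in this commit):
    #     conversation.set_history(token_ids)
    # After:
    conversation = Conversation(
        past_user_inputs=["Hi there"], generated_responses=["Hello, how can I help?"]
    )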
@@ -167,12 +184,40 @@ class ConversationalPipeline(Pipeline):
         super().__init__(*args, **kwargs)

         # We need at least an eos_token
-        assert self.tokenizer.eos_token_id is not None, "DialoguePipeline tokenizer should have an EOS token set"
+        assert self.tokenizer.eos_token_id is not None, "ConversationalPipeline tokenizer should have an EOS token set"
         if self.tokenizer.pad_token_id is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token

         self.min_length_for_response = min_length_for_response

+    def _get_history(self, conversation):
+        """
+        Private function (subject to change) that simply tokenizes and concatenates past inputs. Also saves that
+        tokenization into the conversation state.
+
+        Args:
+            conversation (:class:`~transformers.Conversation`)
+
+        Returns:
+            :obj:`List[int]`: The list of tokens for the past input of that conversation.
+        """
+        # Make a copy to prevent messing the cache up if there's an error
+        # within this function
+        history = conversation._history.copy()
+        index = conversation._index
+        new_index = index
+        for i, (past_user_input, generated_response) in enumerate(
+            zip(conversation.past_user_inputs[index:], conversation.generated_responses[index:])
+        ):
+            for el in (past_user_input, generated_response):
+                new_history = self._parse_and_tokenize([el])[0]
+                history.extend(new_history)
+            new_index = i + index + 1
+        conversation._index = new_index
+        conversation._history = history
+        # Hand back a copy to the caller so they can't accidentally modify our cache.
+        return history.copy()
+
     def __call__(
         self,
         conversations: Union[Conversation, List[Conversation]],
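The cache added here is incremental: `_index` records how many completed (user input, response) pairs are already tokenized into `_history`, so repeated calls only pay for new turns. The same idea in isolation, as a sketch with a generic `tokenize` callable standing in for the pipeline's `_parse_and_tokenize`::

    def get_history(conv, tokenize):
        # Work on copies so an error mid-loop cannot corrupt the cached state.
        history = conv._history.copy()
        index = conv._index
        for user_input, response in zip(
            conv.past_user_inputs[index:], conv.generated_responses[index:]
        ):
            history.extend(tokenize(user_input))
            history.extend(tokenize(response))
            index += 1
        conv._index = index
        conv._history = history
        return history.copy()  # callers never receive the cache itself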
@@ -220,7 +265,7 @@ class ConversationalPipeline(Pipeline):
         with self.device_placement():

             inputs = self._parse_and_tokenize([conversation.new_user_input for conversation in conversations])
-            histories = [conversation.history for conversation in conversations]
+            histories = [self._get_history(conversation) for conversation in conversations]
             max_length = generate_kwargs.get("max_length", self.model.config.max_length)
             inputs = self._concat_inputs_history(inputs, histories, max_length)

@@ -266,7 +311,6 @@ class ConversationalPipeline(Pipeline):
                         clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                     )
                 )
-                conversation.set_history(history[conversation_index])
                 output.append(conversation)
             if len(output) == 1:
                 return output[0]
@@ -333,7 +377,9 @@ class ConversationalPipeline(Pipeline):
                     if cutoff_eos_index == 0 or cutoff_eos_index == len(new_input) - 1:
                         break
                     else:
                         new_input = new_input[cutoff_eos_index + 1 :]
+                logger.warning(
+                    f"Cutting history off because it's too long ({len(new_input)} > {max_length - self.min_length_for_response}) for underlying model"
+                )
             outputs.append(new_input)
         padded_outputs = self.tokenizer.pad(
             {"input_ids": outputs}, padding="longest", return_attention_mask=True, return_tensors=self.framework

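A note on the numbers in this new warning: the history budget is max_length - min_length_for_response. Assuming the pipeline's default min_length_for_response of 32 (taken from the pipeline signature, not shown in this diff), a call with max_length=60 leaves 28 tokens for history, which is exactly the "63 > 28" message asserted in the tests below::

    max_length = 60
    min_length_for_response = 32  # pipeline default (assumed, not in this diff)
    budget = max_length - min_length_for_response
    assert budget == 28  # hence "Cutting history off ... (63 > 28)"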
tests/test_pipelines_conversational.py

@@ -14,15 +14,177 @@

 import unittest

-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Conversation, ConversationalPipeline, pipeline
+from transformers import (
+    AutoModelForSeq2SeqLM,
+    AutoTokenizer,
+    Conversation,
+    ConversationalPipeline,
+    is_torch_available,
+    pipeline,
+)
 from transformers.testing_utils import require_torch, slow, torch_device

 from .test_pipelines_common import MonoInputPipelineCommonMixin


+if is_torch_available():
+    import torch
+
+    from transformers.models.gpt2 import GPT2Config, GPT2LMHeadModel
+
 DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0


+class SimpleConversationPipelineTests(unittest.TestCase):
+    def get_pipeline(self):
+        # When
+        config = GPT2Config(
+            vocab_size=263,
+            n_ctx=128,
+            max_length=128,
+            n_embd=64,
+            n_layer=1,
+            n_head=8,
+            bos_token_id=256,
+            eos_token_id=257,
+        )
+        model = GPT2LMHeadModel(config)
+        # Force model output to be L (token 76)
+        V, D = model.lm_head.weight.shape
+        bias = torch.zeros(V, requires_grad=True)
+        bias[76] = 1
+
+        model.lm_head.bias = torch.nn.Parameter(bias)
+
+        # # Created with:
+        # import tempfile
+
+        # from tokenizers import Tokenizer, models
+        # from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
+
+        # vocab = [(chr(i), i) for i in range(256)]
+        # tokenizer = Tokenizer(models.Unigram(vocab))
+        # with tempfile.NamedTemporaryFile() as f:
+        #     tokenizer.save(f.name)
+        #     real_tokenizer = PreTrainedTokenizerFast(tokenizer_file=f.name, eos_token="<eos>", bos_token="<bos>")
+
+        # real_tokenizer._tokenizer.save("dummy.json")
+        # Special tokens are automatically added at load time.
+        tokenizer = AutoTokenizer.from_pretrained("Narsil/small_conversational_test")
+        conversation_agent = pipeline(
+            task="conversational", device=DEFAULT_DEVICE_NUM, model=model, tokenizer=tokenizer
+        )
+        return conversation_agent
+
+    @require_torch
+    def test_integration_torch_conversation(self):
+        conversation_agent = self.get_pipeline()
+        conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
+        conversation_2 = Conversation("What's the last book you have read?")
+        self.assertEqual(len(conversation_1.past_user_inputs), 0)
+        self.assertEqual(len(conversation_2.past_user_inputs), 0)
+
+        with self.assertLogs("transformers", level="WARNING") as log:
+            result = conversation_agent([conversation_1, conversation_2], max_length=48)
+            self.assertEqual(len(log.output), 2)
+            self.assertIn("You might consider trimming the early phase of the conversation", log.output[0])
+            self.assertIn("Setting `pad_token_id`", log.output[1])
+
+        # Two conversations in one pass
+        self.assertEqual(result, [conversation_1, conversation_2])
+        self.assertEqual(
+            result,
+            [
+                Conversation(
+                    None,
+                    past_user_inputs=["Going to the movies tonight - any suggestions?"],
+                    generated_responses=["L"],
+                ),
+                Conversation(
+                    None, past_user_inputs=["What's the last book you have read?"], generated_responses=["L"]
+                ),
+            ],
+        )
+
+        # One conversation with history
+        conversation_2.add_user_input("Why do you recommend it?")
+        with self.assertLogs("transformers", level="WARNING") as log:
+            result = conversation_agent(conversation_2, max_length=64)
+            self.assertEqual(len(log.output), 3)
+            self.assertIn("Cutting history off because it's too long", log.output[0])
+            self.assertIn("You might consider trimming the early phase of the conversation", log.output[1])
+            self.assertIn("Setting `pad_token_id`", log.output[2])
+
+        self.assertEqual(result, conversation_2)
+        self.assertEqual(
+            result,
+            Conversation(
+                None,
+                past_user_inputs=["What's the last book you have read?", "Why do you recommend it?"],
+                generated_responses=["L", "L"],
+            ),
+        )
+
+    @require_torch
+    def test_history_cache(self):
+        conversation_agent = self.get_pipeline()
+        conversation = Conversation(
+            "Why do you recommend it?",
+            past_user_inputs=["What's the last book you have read?"],
+            generated_responses=["b"],
+        )
+        with self.assertLogs("transformers", level="WARNING") as log:
+            _ = conversation_agent(conversation, max_length=60)
+            self.assertEqual(len(log.output), 3)
+            self.assertIn("Cutting history off because it's too long (63 > 28) for underlying model", log.output[0])
+            self.assertIn("63 is bigger than 0.9 * max_length: 60", log.output[1])
+            self.assertIn("Setting `pad_token_id`", log.output[2])
+        self.assertEqual(conversation._index, 1)
+        self.assertEqual(
+            conversation._history,
+            # Byte-level ids spelling "What's the last book you have read?", then EOS, "b", EOS
+            [
+                87, 104, 97, 116, 39, 115, 32, 116, 104, 101, 32, 108, 97, 115, 116,
+                32, 98, 111, 111, 107, 32, 121, 111, 117, 32, 104, 97, 118, 101, 32,
+                114, 101, 97, 100, 63,
+                259,  # EOS
+                98,  # b
+                259,  # EOS
+            ],
+        )
+
+
 class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
     pipeline_task = "conversational"
     small_models = []  # Models tested without the @slow decorator
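Why these tests are deterministic: get_pipeline rigs the untrained GPT2 head with a unit bias on vocab index 76 (chr(76) is "L" in this byte-level vocab), so greedy decoding reliably emits "L" because the untrained weights produce logits that are small next to the bias. A standalone sketch of the trick::

    import torch

    vocab_size = 263
    bias = torch.zeros(vocab_size)
    bias[76] = 1  # chr(76) == "L"

    # An untrained head produces small, near-zero logits, so the unit bias
    # dominates the argmax and generation becomes reproducible.
    torch.manual_seed(0)
    logits = 0.01 * torch.randn(vocab_size) + bias
    assert logits.argmax().item() == 76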