Making it possible to create a full Conversation directly (#9434)

* Cleaning up conversation tests.

* Adding tests that don't require downloading models + conversation can be fully created from static state.

* Making tests non-flaky (by fixing generation length)

* Bumping isort version.

* Doc cleanup.

* Remove unused test in this PR.

* Torch import guard for TF.

* Missing torch guard.

* Small mistake in doc.

* Actually use the `_history` and `_index` cache.

+ remove dead enumerate
+ improve warning message.

* Update src/transformers/pipelines/conversational.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/pipelines/conversational.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/pipelines/conversational.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Adding comments and cleaner code to address history copy.

* Improving pipeline name in tests.

* Change tokenizer to a real one (still created at runtime with no external dependency)

* Simplify DummyTok, revert changes to tokenization.

* Removing DummyTok.

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Nicolas Patry 2021-01-08 14:33:25 +01:00 committed by GitHub
parent 4fbcf8ea49
commit 02e05fb0a5
2 changed files with 227 additions and 19 deletions
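
For context, the main change: a `Conversation` can now be rebuilt entirely from static state through the new `past_user_inputs` / `generated_responses` arguments. A minimal sketch of the intended usage (the model name is only a placeholder; any model usable with the conversational task works):

from transformers import Conversation, pipeline

# Recreate a conversation from previously stored turns instead of replaying
# them through the pipeline one by one.
conversation = Conversation(
    "Why do you recommend it?",
    past_user_inputs=["What's the last book you have read?"],
    generated_responses=["The Grapes of Wrath."],
)

# Placeholder model; swap in whatever conversational model you actually use.
conversational_pipeline = pipeline("conversational", model="microsoft/DialoGPT-small")
conversation = conversational_pipeline(conversation)
print(conversation.generated_responses[-1])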

src/transformers/pipelines/conversational.py
@@ -33,6 +33,14 @@ class Conversation:
conversation_id (:obj:`uuid.UUID`, `optional`):
Unique identifier for the conversation. If not provided, a random UUID4 id will be assigned to the
conversation.
past_user_inputs (:obj:`List[str]`, `optional`):
Eventual past history of the conversation of the user. You don't need to pass it manually if you use the
pipeline interactively but if you want to recreate history you need to set both :obj:`past_user_inputs` and
:obj:`generated_responses` with equal length lists of strings
generated_responses (:obj:`List[str]`, `optional`):
Eventual past history of the conversation of the model. You don't need to pass it manually if you use the
pipeline interactively but if you want to recreate history you need to set both :obj:`past_user_inputs` and
:obj:`generated_responses` with equal length lists of strings
Usage::
@@ -47,14 +55,33 @@ class Conversation:
conversation.add_user_input("Is it good?")
"""
def __init__(self, text: str = None, conversation_id: uuid.UUID = None):
def __init__(
self, text: str = None, conversation_id: uuid.UUID = None, past_user_inputs=None, generated_responses=None
):
if not conversation_id:
conversation_id = uuid.uuid4()
if past_user_inputs is None:
past_user_inputs = []
if generated_responses is None:
generated_responses = []
self.uuid: uuid.UUID = conversation_id
self.past_user_inputs: List[str] = []
self.generated_responses: List[str] = []
self.history: List[int] = []
self.past_user_inputs: List[str] = past_user_inputs
self.generated_responses: List[str] = generated_responses
self.new_user_input: Optional[str] = text
self._index: int = 0
self._history: List[int] = []
def __eq__(self, other):
if not isinstance(other, Conversation):
return False
if self.uuid == other.uuid:
return True
return (
self.new_user_input == other.new_user_input
and self.past_user_inputs == other.past_user_inputs
and self.generated_responses == other.generated_responses
)
def add_user_input(self, text: str, overwrite: bool = False):
"""
@@ -100,16 +127,6 @@ class Conversation:
"""
self.generated_responses.append(response)
def set_history(self, history: List[int]):
"""
Updates the value of the history of the conversation. The history is represented by a list of :obj:`token_ids`.
The history is used by the model to generate responses based on the previous conversation turns.
Args:
history (:obj:`List[int]`): History of tokens provided and generated for this conversation.
"""
self.history = history
def __repr__(self):
"""
Generates a string representation of the conversation.
@@ -167,12 +184,40 @@ class ConversationalPipeline(Pipeline):
super().__init__(*args, **kwargs)
# We need at least an eos_token
assert self.tokenizer.eos_token_id is not None, "DialoguePipeline tokenizer should have an EOS token set"
assert self.tokenizer.eos_token_id is not None, "ConversationalPipeline tokenizer should have an EOS token set"
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.min_length_for_response = min_length_for_response
def _get_history(self, conversation):
"""
Private function (subject to change) that simply tokenizes and concatenates past inputs. Also saves that
tokenization into the conversation state.
Args:
conversation (:class:`~transformers.Conversation`)
Returns:
:obj:`List[int]`: The list of tokens for the past input of that conversation.
"""
# Make a copy to prevent messing up the cache if there's an error
# within this function
history = conversation._history.copy()
index = conversation._index
new_index = index
for i, (past_user_input, generated_response) in enumerate(
zip(conversation.past_user_inputs[index:], conversation.generated_responses[index:])
):
for el in (past_user_input, generated_response):
new_history = self._parse_and_tokenize([el])[0]
history.extend(new_history)
new_index = i + index + 1
conversation._index = new_index
conversation._history = history
# Hand back a copy to the caller so they can't accidentally modify our cache.
return history.copy()
def __call__(
self,
conversations: Union[Conversation, List[Conversation]],
@@ -220,7 +265,7 @@ class ConversationalPipeline(Pipeline):
with self.device_placement():
inputs = self._parse_and_tokenize([conversation.new_user_input for conversation in conversations])
histories = [conversation.history for conversation in conversations]
histories = [self._get_history(conversation) for conversation in conversations]
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
inputs = self._concat_inputs_history(inputs, histories, max_length)
@@ -266,7 +311,6 @@ class ConversationalPipeline(Pipeline):
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
)
conversation.set_history(history[conversation_index])
output.append(conversation)
if len(output) == 1:
return output[0]
@@ -333,7 +377,9 @@ class ConversationalPipeline(Pipeline):
if cutoff_eos_index == 0 or cutoff_eos_index == len(new_input) - 1:
break
else:
new_input = new_input[cutoff_eos_index + 1 :]
logger.warning(
f"Cutting history off because it's too long ({len(new_input)} > {max_length - self.min_length_for_response}) for underlying model"
)
outputs.append(new_input)
padded_outputs = self.tokenizer.pad(
{"input_ids": outputs}, padding="longest", return_attention_mask=True, return_tensors=self.framework
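
To make the new `_index` / `_history` cache concrete, here is a rough sketch of the behaviour the diff above implies (reusing the placeholder `conversational_pipeline` from the earlier sketch; the values in the comments match the test below):

conversation = Conversation(
    "Why do you recommend it?",
    past_user_inputs=["What's the last book you have read?"],
    generated_responses=["It was great."],
)
conversation = conversational_pipeline(conversation)
# While building the model input, `_get_history` tokenized the one completed
# turn and cached the result on the conversation itself:
print(conversation._index)         # 1 -> one past turn already tokenized
print(len(conversation._history))  # token ids for that turn, reused next call
# A later call only tokenizes turns added after `_index`; the cached ids are
# concatenated in front of the new input instead of being recomputed.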

tests/test_pipelines_conversational.py
@@ -14,15 +14,177 @@
import unittest
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Conversation, ConversationalPipeline, pipeline
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
Conversation,
ConversationalPipeline,
is_torch_available,
pipeline,
)
from transformers.testing_utils import require_torch, slow, torch_device
from .test_pipelines_common import MonoInputPipelineCommonMixin
if is_torch_available():
import torch
from transformers.models.gpt2 import GPT2Config, GPT2LMHeadModel
DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
class SimpleConversationPipelineTests(unittest.TestCase):
def get_pipeline(self):
# When
config = GPT2Config(
vocab_size=263,
n_ctx=128,
max_length=128,
n_embd=64,
n_layer=1,
n_head=8,
bos_token_id=256,
eos_token_id=257,
)
model = GPT2LMHeadModel(config)
# Force the model to always output "L" (token id 76 == ord("L"))
V, D = model.lm_head.weight.shape
bias = torch.zeros(V, requires_grad=True)
bias[76] = 1
model.lm_head.bias = torch.nn.Parameter(bias)
# # Created with:
# import tempfile
# from tokenizers import Tokenizer, models
# from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
# vocab = [(chr(i), i) for i in range(256)]
# tokenizer = Tokenizer(models.Unigram(vocab))
# with tempfile.NamedTemporaryFile() as f:
# tokenizer.save(f.name)
# real_tokenizer = PreTrainedTokenizerFast(tokenizer_file=f.name, eos_token="<eos>", bos_token="<bos>")
# real_tokenizer._tokenizer.save("dummy.json")
# Special tokens are automatically added at load time.
tokenizer = AutoTokenizer.from_pretrained("Narsil/small_conversational_test")
conversation_agent = pipeline(
task="conversational", device=DEFAULT_DEVICE_NUM, model=model, tokenizer=tokenizer
)
return conversation_agent
@require_torch
def test_integration_torch_conversation(self):
conversation_agent = self.get_pipeline()
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
conversation_2 = Conversation("What's the last book you have read?")
self.assertEqual(len(conversation_1.past_user_inputs), 0)
self.assertEqual(len(conversation_2.past_user_inputs), 0)
with self.assertLogs("transformers", level="WARNING") as log:
result = conversation_agent([conversation_1, conversation_2], max_length=48)
self.assertEqual(len(log.output), 2)
self.assertIn("You might consider trimming the early phase of the conversation", log.output[0])
self.assertIn("Setting `pad_token_id`", log.output[1])
# Two conversations in one pass
self.assertEqual(result, [conversation_1, conversation_2])
self.assertEqual(
result,
[
Conversation(
None,
past_user_inputs=["Going to the movies tonight - any suggestions?"],
generated_responses=["L"],
),
Conversation(
None, past_user_inputs=["What's the last book you have read?"], generated_responses=["L"]
),
],
)
# One conversation with history
conversation_2.add_user_input("Why do you recommend it?")
with self.assertLogs("transformers", level="WARNING") as log:
result = conversation_agent(conversation_2, max_length=64)
self.assertEqual(len(log.output), 3)
self.assertIn("Cutting history off because it's too long", log.output[0])
self.assertIn("You might consider trimming the early phase of the conversation", log.output[1])
self.assertIn("Setting `pad_token_id`", log.output[2])
self.assertEqual(result, conversation_2)
self.assertEqual(
result,
Conversation(
None,
past_user_inputs=["What's the last book you have read?", "Why do you recommend it?"],
generated_responses=["L", "L"],
),
)
@require_torch
def test_history_cache(self):
conversation_agent = self.get_pipeline()
conversation = Conversation(
"Why do you recommend it?",
past_user_inputs=["What's the last book you have read?"],
generated_responses=["b"],
)
with self.assertLogs("transformers", level="WARNING") as log:
_ = conversation_agent(conversation, max_length=60)
self.assertEqual(len(log.output), 3)
self.assertIn("Cutting history off because it's too long (63 > 28) for underlying model", log.output[0])
self.assertIn("63 is bigger than 0.9 * max_length: 60", log.output[1])
self.assertIn("Setting `pad_token_id`", log.output[2])
self.assertEqual(conversation._index, 1)
self.assertEqual(
conversation._history,
[
87,
104,
97,
116,
39,
115,
32,
116,
104,
101,
32,
108,
97,
115,
116,
32,
98,
111,
111,
107,
32,
121,
111,
117,
32,
104,
97,
118,
101,
32,
114,
101,
97,
100,
63,
259, # EOS
98, # b
259, # EOS
],
)
class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
pipeline_task = "conversational"
small_models = [] # Models tested without the @slow decorator