diff --git a/src/transformers/pipelines/conversational.py b/src/transformers/pipelines/conversational.py
index d6f0e2517f..9dce94626d 100644
--- a/src/transformers/pipelines/conversational.py
+++ b/src/transformers/pipelines/conversational.py
@@ -33,6 +33,14 @@ class Conversation:
         conversation_id (:obj:`uuid.UUID`, `optional`):
             Unique identifier for the conversation. If not provided, a random UUID4 id will be assigned to the
             conversation.
+        past_user_inputs (:obj:`List[str]`, `optional`):
+            Past user inputs of the conversation, if any. You don't need to pass them manually if you use the
+            pipeline interactively, but if you want to recreate a conversation's history you need to set both
+            :obj:`past_user_inputs` and :obj:`generated_responses` with equal-length lists of strings.
+        generated_responses (:obj:`List[str]`, `optional`):
+            Past responses of the model, if any. You don't need to pass them manually if you use the pipeline
+            interactively, but if you want to recreate a conversation's history you need to set both
+            :obj:`past_user_inputs` and :obj:`generated_responses` with equal-length lists of strings.
 
     Usage::
 
@@ -47,14 +55,33 @@ class Conversation:
         conversation.add_user_input("Is it good?")
     """
 
-    def __init__(self, text: str = None, conversation_id: uuid.UUID = None):
+    def __init__(
+        self, text: str = None, conversation_id: uuid.UUID = None, past_user_inputs=None, generated_responses=None
+    ):
         if not conversation_id:
             conversation_id = uuid.uuid4()
+        if past_user_inputs is None:
+            past_user_inputs = []
+        if generated_responses is None:
+            generated_responses = []
+
         self.uuid: uuid.UUID = conversation_id
-        self.past_user_inputs: List[str] = []
-        self.generated_responses: List[str] = []
-        self.history: List[int] = []
+        self.past_user_inputs: List[str] = past_user_inputs
+        self.generated_responses: List[str] = generated_responses
         self.new_user_input: Optional[str] = text
+        self._index: int = 0
+        self._history: List[int] = []
+
+    def __eq__(self, other):
+        if not isinstance(other, Conversation):
+            return False
+        if self.uuid == other.uuid:
+            return True
+        return (
+            self.new_user_input == other.new_user_input
+            and self.past_user_inputs == other.past_user_inputs
+            and self.generated_responses == other.generated_responses
+        )
 
     def add_user_input(self, text: str, overwrite: bool = False):
         """
@@ -100,16 +127,6 @@ class Conversation:
         """
         self.generated_responses.append(response)
 
-    def set_history(self, history: List[int]):
-        """
-        Updates the value of the history of the conversation. The history is represented by a list of :obj:`token_ids`.
-        The history is used by the model to generate responses based on the previous conversation turns.
-
-        Args:
-            history (:obj:`List[int]`): History of tokens provided and generated for this conversation.
-        """
-        self.history = history
-
     def __repr__(self):
        """
        Generates a string representation of the conversation.
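Note on the two hunks above: the new constructor arguments let a conversation's history be rebuilt from plain strings, and the new ``__eq__`` falls back to comparing content when the UUIDs differ, which is how the updated tests compare pipeline outputs against freshly built ``Conversation`` objects. A minimal sketch of both behaviours, using only the API shown in this diff (the book title is made up for illustration)::

    from transformers import Conversation

    restored = Conversation(
        "Why do you recommend it?",
        past_user_inputs=["What's the last book you have read?"],
        generated_responses=["The Grapes of Wrath"],
    )
    same_content = Conversation(
        "Why do you recommend it?",
        past_user_inputs=["What's the last book you have read?"],
        generated_responses=["The Grapes of Wrath"],
    )

    # Different UUIDs, same content: still equal, via the content comparison.
    assert restored.uuid != same_content.uuid
    assert restored == same_content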
@@ -167,12 +184,40 @@ class ConversationalPipeline(Pipeline):
         super().__init__(*args, **kwargs)
 
         # We need at least an eos_token
-        assert self.tokenizer.eos_token_id is not None, "DialoguePipeline tokenizer should have an EOS token set"
+        assert self.tokenizer.eos_token_id is not None, "ConversationalPipeline tokenizer should have an EOS token set"
         if self.tokenizer.pad_token_id is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
         self.min_length_for_response = min_length_for_response
 
+    def _get_history(self, conversation):
+        """
+        Private method (subject to change) that tokenizes and concatenates the past inputs of a conversation, and
+        caches that tokenization on the conversation state.
+
+        Args:
+            conversation (:class:`~transformers.Conversation`)
+
+        Returns:
+            :obj:`List[int]`: The list of tokens for the past inputs of that conversation.
+        """
+        # Work on a copy so the cache is not corrupted if an error is raised
+        # within this function
+        history = conversation._history.copy()
+        index = conversation._index
+        new_index = index
+        for i, (past_user_input, generated_response) in enumerate(
+            zip(conversation.past_user_inputs[index:], conversation.generated_responses[index:])
+        ):
+            for el in (past_user_input, generated_response):
+                new_history = self._parse_and_tokenize([el])[0]
+                history.extend(new_history)
+            new_index = i + index + 1
+        conversation._index = new_index
+        conversation._history = history
+        # Hand back a copy to the caller so they can't accidentally modify our cache.
+        return history.copy()
+
     def __call__(
         self,
         conversations: Union[Conversation, List[Conversation]],
@@ -220,7 +265,7 @@ class ConversationalPipeline(Pipeline):
 
         with self.device_placement():
             inputs = self._parse_and_tokenize([conversation.new_user_input for conversation in conversations])
-            histories = [conversation.history for conversation in conversations]
+            histories = [self._get_history(conversation) for conversation in conversations]
             max_length = generate_kwargs.get("max_length", self.model.config.max_length)
             inputs = self._concat_inputs_history(inputs, histories, max_length)
 
@@ -266,7 +311,6 @@ class ConversationalPipeline(Pipeline):
                         clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                     )
                 )
-                conversation.set_history(history[conversation_index])
                 output.append(conversation)
             if len(output) == 1:
                 return output[0]
@@ -333,7 +377,9 @@ class ConversationalPipeline(Pipeline):
                     if cutoff_eos_index == 0 or cutoff_eos_index == len(new_input) - 1:
                         break
                     else:
-                        new_input = new_input[cutoff_eos_index + 1 :]
+                        logger.warning(
+                            f"Cutting history off because it's too long ({len(new_input)} > {max_length - self.min_length_for_response}) for underlying model"
+                        )
             outputs.append(new_input)
         padded_outputs = self.tokenizer.pad(
             {"input_ids": outputs}, padding="longest", return_attention_mask=True, return_tensors=self.framework
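With ``set_history`` gone, callers no longer hand token ids to the conversation: the pipeline tokenizes the string history itself in ``_get_history`` and caches the result on the private ``_history``/``_index`` attributes, so later calls only tokenize the new turns. A rough usage sketch of that flow (the DialoGPT checkpoint is just an assumed example; any conversational checkpoint should work)::

    from transformers import Conversation, pipeline

    conversational_pipeline = pipeline("conversational", model="microsoft/DialoGPT-small")

    # Recreate an earlier session purely from strings.
    conversation = Conversation(
        "Why do you recommend it?",
        past_user_inputs=["What's the last book you have read?"],
        generated_responses=["The Lord of the Rings"],
    )

    # The pipeline tokenizes the history lazily; after this call the token ids are
    # cached on conversation._history / conversation._index, so the next turn only
    # tokenizes the new exchanges.
    conversation = conversational_pipeline(conversation)
    print(conversation.generated_responses[-1])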
diff --git a/tests/test_pipelines_conversational.py b/tests/test_pipelines_conversational.py
index 16703b1113..ad00d92b3b 100644
--- a/tests/test_pipelines_conversational.py
+++ b/tests/test_pipelines_conversational.py
@@ -14,15 +14,177 @@
 
 import unittest
 
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Conversation, ConversationalPipeline, pipeline
+from transformers import (
+    AutoModelForSeq2SeqLM,
+    AutoTokenizer,
+    Conversation,
+    ConversationalPipeline,
+    is_torch_available,
+    pipeline,
+)
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_pipelines_common import MonoInputPipelineCommonMixin
 
 
+if is_torch_available():
+    import torch
+
+    from transformers.models.gpt2 import GPT2Config, GPT2LMHeadModel
+
 DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
 
 
+class SimpleConversationPipelineTests(unittest.TestCase):
+    def get_pipeline(self):
+        # Tiny randomly-initialized GPT-2 so the tests run fast
+        config = GPT2Config(
+            vocab_size=263,
+            n_ctx=128,
+            max_length=128,
+            n_embd=64,
+            n_layer=1,
+            n_head=8,
+            bos_token_id=256,
+            eos_token_id=257,
+        )
+        model = GPT2LMHeadModel(config)
+        # Bias the LM head so the model always generates token 76 ("L")
+        V, D = model.lm_head.weight.shape
+        bias = torch.zeros(V, requires_grad=True)
+        bias[76] = 1
+
+        model.lm_head.bias = torch.nn.Parameter(bias)
+
+        # # Created with:
+        # import tempfile
+
+        # from tokenizers import Tokenizer, models
+        # from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
+
+        # vocab = [(chr(i), i) for i in range(256)]
+        # tokenizer = Tokenizer(models.Unigram(vocab))
+        # with tempfile.NamedTemporaryFile() as f:
+        #     tokenizer.save(f.name)
+        #     real_tokenizer = PreTrainedTokenizerFast(tokenizer_file=f.name, eos_token="", bos_token="")
+
+        # real_tokenizer._tokenizer.save("dummy.json")
+        # Special tokens are automatically added at load time.
+        tokenizer = AutoTokenizer.from_pretrained("Narsil/small_conversational_test")
+        conversation_agent = pipeline(
+            task="conversational", device=DEFAULT_DEVICE_NUM, model=model, tokenizer=tokenizer
+        )
+        return conversation_agent
+
+    @require_torch
+    def test_integration_torch_conversation(self):
+        conversation_agent = self.get_pipeline()
+        conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
+        conversation_2 = Conversation("What's the last book you have read?")
+        self.assertEqual(len(conversation_1.past_user_inputs), 0)
+        self.assertEqual(len(conversation_2.past_user_inputs), 0)
+
+        with self.assertLogs("transformers", level="WARNING") as log:
+            result = conversation_agent([conversation_1, conversation_2], max_length=48)
+            self.assertEqual(len(log.output), 2)
+            self.assertIn("You might consider trimming the early phase of the conversation", log.output[0])
+            self.assertIn("Setting `pad_token_id`", log.output[1])
+
+        # Two conversations in one pass
+        self.assertEqual(result, [conversation_1, conversation_2])
+        self.assertEqual(
+            result,
+            [
+                Conversation(
+                    None,
+                    past_user_inputs=["Going to the movies tonight - any suggestions?"],
+                    generated_responses=["L"],
+                ),
+                Conversation(
+                    None, past_user_inputs=["What's the last book you have read?"], generated_responses=["L"]
+                ),
+            ],
+        )
+
+        # One conversation with history
+        conversation_2.add_user_input("Why do you recommend it?")
+        with self.assertLogs("transformers", level="WARNING") as log:
+            result = conversation_agent(conversation_2, max_length=64)
+            self.assertEqual(len(log.output), 3)
+            self.assertIn("Cutting history off because it's too long", log.output[0])
+            self.assertIn("You might consider trimming the early phase of the conversation", log.output[1])
+            self.assertIn("Setting `pad_token_id`", log.output[2])
+
+        self.assertEqual(result, conversation_2)
+        self.assertEqual(
+            result,
+            Conversation(
+                None,
+                past_user_inputs=["What's the last book you have read?", "Why do you recommend it?"],
+                generated_responses=["L", "L"],
+            ),
+        )
+
+    @require_torch
+    def test_history_cache(self):
+        conversation_agent = self.get_pipeline()
+        conversation = Conversation(
+            "Why do you recommend it?",
+            past_user_inputs=["What's the last book you have read?"],
+            generated_responses=["b"],
+        )
+        with self.assertLogs("transformers", level="WARNING") as log:
+            _ = conversation_agent(conversation, max_length=60)
+            self.assertEqual(len(log.output), 3)
+            self.assertIn("Cutting history off because it's too long (63 > 28) for underlying model", log.output[0])
+            self.assertIn("63 is bigger than 0.9 * max_length: 60", log.output[1])
+            self.assertIn("Setting `pad_token_id`", log.output[2])
+        self.assertEqual(conversation._index, 1)
+        self.assertEqual(
+            conversation._history,
+            [
+                87,
+                104,
+                97,
+                116,
+                39,
+                115,
+                32,
+                116,
+                104,
+                101,
+                32,
+                108,
+                97,
+                115,
+                116,
+                32,
+                98,
+                111,
+                111,
+                107,
+                32,
+                121,
+                111,
+                117,
+                32,
+                104,
+                97,
+                118,
+                101,
+                32,
+                114,
+                101,
+                97,
+                100,
+                63,
+                259,  # EOS
+                98,  # b
+                259,  # EOS
+            ],
+        )
+
+
 class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
     pipeline_task = "conversational"
     small_models = []  # Models tested without the @slow decorator
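A note for readers of the hard-coded ``_history`` in ``test_history_cache``: the dummy tokenizer is character-level (``vocab = [(chr(i), i) for i in range(256)]``), so the expected ids are simply the character codes of the past user input and of the pre-seeded response, each followed by the tokenizer's EOS id (259, per the inline comments). A quick sanity check, assuming that reading of the fixture::

    eos = 259  # EOS id of the dummy tokenizer, per the comments in the test
    expected = (
        [ord(c) for c in "What's the last book you have read?"]
        + [eos]
        + [ord("b")]  # the pre-seeded generated response "b"
        + [eos]
    )
    # `expected` reproduces the 38-element list asserted against `conversation._history`.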