Fix ConversationalPipeline tests (#26217)

Add BlenderbotSmall templates and correct handling for conversation.past_user_inputs
This commit is contained in:
Matt 2023-09-18 15:08:56 +01:00 committed by GitHub
parent bc7ce1808f
commit f0a6057fbc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 5 deletions

View File

@ -236,3 +236,18 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer):
index += 1
return vocab_file, merge_file
@property
# Copied from transformers.models.blenderbot.tokenization_blenderbot.BlenderbotTokenizer.default_chat_template
def default_chat_template(self):
"""
A very simple chat template that just adds whitespace between messages.
"""
return (
"{% for message in messages %}"
"{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}"
"{{ message['content'] }}"
"{% if not loop.last %}{{ ' ' }}{% endif %}"
"{% endfor %}"
"{{ eos_token }}"
)

View File

@ -117,3 +117,18 @@ class BlenderbotSmallTokenizerFast(PreTrainedTokenizerFast):
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
@property
# Copied from transformers.models.blenderbot.tokenization_blenderbot.BlenderbotTokenizer.default_chat_template
def default_chat_template(self):
"""
A very simple chat template that just adds whitespace between messages.
"""
return (
"{% for message in messages %}"
"{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}"
"{{ message['content'] }}"
"{% if not loop.last %}{{ ' ' }}{% endif %}"
"{% endfor %}"
"{{ eos_token }}"
)

View File

@ -140,8 +140,8 @@ class ConversationalPipelineTests(unittest.TestCase):
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
conversation_2 = Conversation("What's the last book you have read?")
# Then
self.assertEqual(len(conversation_1.past_user_inputs), 0)
self.assertEqual(len(conversation_2.past_user_inputs), 0)
self.assertEqual(len(conversation_1.past_user_inputs), 1)
self.assertEqual(len(conversation_2.past_user_inputs), 1)
# When
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
# Then
@ -171,7 +171,7 @@ class ConversationalPipelineTests(unittest.TestCase):
conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=DEFAULT_DEVICE_NUM)
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
# Then
self.assertEqual(len(conversation_1.past_user_inputs), 0)
self.assertEqual(len(conversation_1.past_user_inputs), 1)
# When
result = conversation_agent(conversation_1, do_sample=False, max_length=36)
# Then
@ -379,8 +379,8 @@ These are just a few of the many attractions that Paris has to offer. With so mu
conversation_1 = Conversation("My name is Sarah and I live in London")
conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ")
# Then
self.assertEqual(len(conversation_1.past_user_inputs), 0)
self.assertEqual(len(conversation_2.past_user_inputs), 0)
self.assertEqual(len(conversation_1.past_user_inputs), 1)
self.assertEqual(len(conversation_2.past_user_inputs), 1)
# When
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
# Then