Fix ConversationalPipeline tests (#26217)
Add BlenderbotSmall templates and correct handling for conversation.past_user_inputs
This commit is contained in:
parent
bc7ce1808f
commit
f0a6057fbc
|
@@ -236,3 +236,18 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer):
|
|||
index += 1
|
||||
|
||||
return vocab_file, merge_file
|
||||
|
||||
@property
# Copied from transformers.models.blenderbot.tokenization_blenderbot.BlenderbotTokenizer.default_chat_template
def default_chat_template(self):
    """
    A very simple chat template that just adds whitespace between messages.

    Returns:
        `str`: A Jinja template that prepends a single space before each user
        message, emits each message's content verbatim, separates consecutive
        messages with a single space, and appends the tokenizer's EOS token
        after the final message.
    """
    # NOTE(review): this block was reconstructed from a garbled diff extraction
    # (stray layout lines, lost indentation). The template string below is
    # preserved byte-for-byte from the original hunk.
    return (
        "{% for message in messages %}"
        "{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}"
        "{{ message['content'] }}"
        "{% if not loop.last %}{{ ' ' }}{% endif %}"
        "{% endfor %}"
        "{{ eos_token }}"
    )
|
||||
|
|
|
@@ -117,3 +117,18 @@ class BlenderbotSmallTokenizerFast(PreTrainedTokenizerFast):
|
|||
if token_ids_1 is None:
|
||||
return len(cls + token_ids_0 + sep) * [0]
|
||||
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
|
||||
|
||||
@property
# Copied from transformers.models.blenderbot.tokenization_blenderbot.BlenderbotTokenizer.default_chat_template
def default_chat_template(self):
    """
    A very simple chat template that just adds whitespace between messages.

    Returns:
        `str`: A Jinja template that prepends a single space before each user
        message, emits each message's content verbatim, separates consecutive
        messages with a single space, and appends the tokenizer's EOS token
        after the final message.
    """
    # NOTE(review): this block was reconstructed from a garbled diff extraction
    # (stray layout lines, lost indentation). The template string below is
    # preserved byte-for-byte from the original hunk and must stay identical
    # to the slow-tokenizer copy of this property.
    return (
        "{% for message in messages %}"
        "{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}"
        "{{ message['content'] }}"
        "{% if not loop.last %}{{ ' ' }}{% endif %}"
        "{% endfor %}"
        "{{ eos_token }}"
    )
|
||||
|
|
|
@@ -140,8 +140,8 @@ class ConversationalPipelineTests(unittest.TestCase):
|
|||
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
||||
conversation_2 = Conversation("What's the last book you have read?")
|
||||
# Then
|
||||
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
||||
self.assertEqual(len(conversation_2.past_user_inputs), 0)
|
||||
self.assertEqual(len(conversation_1.past_user_inputs), 1)
|
||||
self.assertEqual(len(conversation_2.past_user_inputs), 1)
|
||||
# When
|
||||
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
||||
# Then
|
||||
|
@@ -171,7 +171,7 @@ class ConversationalPipelineTests(unittest.TestCase):
|
|||
conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=DEFAULT_DEVICE_NUM)
|
||||
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
||||
# Then
|
||||
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
||||
self.assertEqual(len(conversation_1.past_user_inputs), 1)
|
||||
# When
|
||||
result = conversation_agent(conversation_1, do_sample=False, max_length=36)
|
||||
# Then
|
||||
|
@@ -379,8 +379,8 @@ These are just a few of the many attractions that Paris has to offer. With so mu
|
|||
conversation_1 = Conversation("My name is Sarah and I live in London")
|
||||
conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ")
|
||||
# Then
|
||||
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
||||
self.assertEqual(len(conversation_2.past_user_inputs), 0)
|
||||
self.assertEqual(len(conversation_1.past_user_inputs), 1)
|
||||
self.assertEqual(len(conversation_2.past_user_inputs), 1)
|
||||
# When
|
||||
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
||||
# Then
|
||||
|
|
Loading…
Reference in New Issue