From 8891193e83b14ebdcfea939fe4b0897177985048 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 15 Dec 2022 18:46:00 +0100 Subject: [PATCH] [Pipeline] fix failing bloom `pipeline` test (#20778) fix failing `pipeline` test --- tests/pipelines/test_pipelines_text_generation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index 92bda4f810..c0aee8b2db 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -284,10 +284,10 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM ], ) - # torch_dtype not necessary + # torch_dtype will be automatically set to float32 if not provided - check: https://github.com/huggingface/transformers/pull/20602 pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto") self.assertEqual(pipe.model.device, torch.device(0)) - self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16) + self.assertEqual(pipe.model.lm_head.weight.dtype, torch.float32) out = pipe("This is a test") self.assertEqual( out,