[Pipeline] fix failing bloom `pipeline` test (#20778)

fix failing `pipeline` test
This commit is contained in:
Younes Belkada 2022-12-15 18:46:00 +01:00 committed by GitHub
parent b9b70b0e66
commit 8891193e83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -284,10 +284,10 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
],
)
# torch_dtype not necessary
# torch_dtype will be automatically set to float32 if not provided - check: https://github.com/huggingface/transformers/pull/20602
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto")
self.assertEqual(pipe.model.device, torch.device(0))
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.float32)
out = pipe("This is a test")
self.assertEqual(
out,