Don't check the device when device_map=auto (#28351)
When running the case on multi-cards server with devcie_map-auto, It will not always be allocated to device 0, Because other processes may be using these cards. It will select the devices that can accommodate this model. Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
parent
5d36025ca1
commit
03b980990a
|
@ -276,7 +276,6 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||
model="hf-internal-testing/tiny-random-bloom",
|
||||
model_kwargs={"device_map": "auto", "torch_dtype": torch.bfloat16},
|
||||
)
|
||||
self.assertEqual(pipe.model.device, torch.device(0))
|
||||
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
|
||||
out = pipe("This is a test")
|
||||
self.assertEqual(
|
||||
|
@ -293,7 +292,6 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||
|
||||
# Upgraded those two to real pipeline arguments (they just get sent for the model as they're unlikely to mean anything else.)
|
||||
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.bfloat16)
|
||||
self.assertEqual(pipe.model.device, torch.device(0))
|
||||
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
|
||||
out = pipe("This is a test")
|
||||
self.assertEqual(
|
||||
|
@ -310,7 +308,6 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||
|
||||
# torch_dtype will be automatically set to float32 if not provided - check: https://github.com/huggingface/transformers/pull/20602
|
||||
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto")
|
||||
self.assertEqual(pipe.model.device, torch.device(0))
|
||||
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.float32)
|
||||
out = pipe("This is a test")
|
||||
self.assertEqual(
|
||||
|
|
Loading…
Reference in New Issue