diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 33abc1dba1..220ceaa29e 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1666,6 +1666,36 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         mem = mem + mem_bufs
         return mem
 
+    def to(self, *args, **kwargs):
+        # Checks if the model has been loaded in 8-bit
+        if getattr(self, "is_loaded_in_8bit", False):
+            raise ValueError(
+                "`.to` is not supported for `8-bit` models. Please use the model as it is, since the"
+                " model has already been set to the correct devices and cast to the correct `dtype`."
+            )
+        else:
+            return super().to(*args, **kwargs)
+
+    def half(self, *args):
+        # Checks if the model has been loaded in 8-bit
+        if getattr(self, "is_loaded_in_8bit", False):
+            raise ValueError(
+                "`.half()` is not supported for `8-bit` models. Please use the model as it is, since the"
+                " model has already been cast to the correct `dtype`."
+            )
+        else:
+            return super().half(*args)
+
+    def float(self, *args):
+        # Checks if the model has been loaded in 8-bit
+        if getattr(self, "is_loaded_in_8bit", False):
+            raise ValueError(
+                "`.float()` is not supported for `8-bit` models. Please use the model as it is, since the"
+                " model has already been cast to the correct `dtype`."
+            )
+        else:
+            return super().float(*args)
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
         r"""
diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py
index 6e8e7b842a..67af67e1d5 100644
--- a/tests/mixed_int8/test_mixed_int8.py
+++ b/tests/mixed_int8/test_mixed_int8.py
@@ -115,6 +115,46 @@ class MixedInt8Test(BaseMixedInt8Test):
         with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname:
             self.model_8bit.save_pretrained(tmpdirname)
 
+    def test_device_and_dtype_assignment(self):
+        r"""
+        Test whether trying to cast (or assign a device to) a model after converting it to 8-bit throws an error.
+        Also checks that other models are cast correctly.
+        """
+        with self.assertRaises(ValueError):
+            # Tries with a `str`
+            self.model_8bit.to("cpu")
+
+        with self.assertRaises(ValueError):
+            # Tries with a `dtype`
+            self.model_8bit.to(torch.float16)
+
+        with self.assertRaises(ValueError):
+            # Tries with a `device`
+            self.model_8bit.to(torch.device("cuda:0"))
+
+        with self.assertRaises(ValueError):
+            # Tries casting with `.float()`
+            self.model_8bit.float()
+
+        with self.assertRaises(ValueError):
+            # Tries casting with `.half()`
+            self.model_8bit.half()
+
+        # Check that we did not break anything for non-8-bit models
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
+
+        self.model_fp16 = self.model_fp16.to(torch.float32)
+        _ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
+
+        # Check this does not throw an error
+        _ = self.model_fp16.to("cpu")
+
+        # Check this does not throw an error
+        _ = self.model_fp16.half()
+
+        # Check this does not throw an error
+        _ = self.model_fp16.float()
+
 
 class MixedInt8ModelClassesTest(BaseMixedInt8Test):
     def setUp(self):
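
Below is a minimal sketch of the behaviour this patch introduces, assuming `bitsandbytes` is installed and a CUDA device is available; the checkpoint name is only a placeholder and is not part of the diff:

from transformers import AutoModelForCausalLM

# Placeholder checkpoint, used purely for illustration.
model_8bit = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m",
    device_map="auto",
    load_in_8bit=True,  # sets `is_loaded_in_8bit`, the attribute checked by the new guards
)

# With this patch, moving or re-casting an 8-bit model raises a ValueError;
# the same applies to `.half()` and `.float()`.
try:
    model_8bit.to("cpu")
except ValueError as err:
    print(err)  # "`.to` is not supported for `8-bit` models. ..."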