[BNB] Throw `ValueError` when trying to cast or assign (#20409)

* `bnb` ValueError when tries to cast or assign

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* remove docstrings

* change error log

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Younes Belkada 2022-11-23 15:51:50 +01:00 committed by GitHub
parent 03ae1f060b
commit ad654e4484
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 70 additions and 0 deletions

View File

@ -1666,6 +1666,36 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
mem = mem + mem_bufs
return mem
def to(self, *args, **kwargs):
# Checks if the model has been loaded in 8-bit
if getattr(self, "is_loaded_in_8bit", False):
raise ValueError(
"`.to` is not supported for `8-bit` models. Please use the model as it is, since the"
" model has already been set to the correct devices and casted to the correct `dtype`."
)
else:
return super().to(*args, **kwargs)
def half(self, *args):
# Checks if the model has been loaded in 8-bit
if getattr(self, "is_loaded_in_8bit", False):
raise ValueError(
"`.half()` is not supported for `8-bit` models. Please use the model as it is, since the"
" model has already been casted to the correct `dtype`."
)
else:
return super().half(*args)
def float(self, *args):
# Checks if the model has been loaded in 8-bit
if getattr(self, "is_loaded_in_8bit", False):
raise ValueError(
"`.float()` is not supported for `8-bit` models. Please use the model as it is, since the"
" model has already been casted to the correct `dtype`."
)
else:
return super().float(*args)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
r"""

View File

@ -115,6 +115,46 @@ class MixedInt8Test(BaseMixedInt8Test):
with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname:
self.model_8bit.save_pretrained(tmpdirname)
def test_device_and_dtype_assignment(self):
r"""
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
Checks also if other models are casted correctly.
"""
with self.assertRaises(ValueError):
# Tries with `str`
self.model_8bit.to("cpu")
with self.assertRaises(ValueError):
# Tries with a `dtype``
self.model_8bit.to(torch.float16)
with self.assertRaises(ValueError):
# Tries with a `device`
self.model_8bit.to(torch.device("cuda:0"))
with self.assertRaises(ValueError):
# Tries with a `device`
self.model_8bit.float()
with self.assertRaises(ValueError):
# Tries with a `device`
self.model_8bit.half()
# Test if we did not break anything
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
self.model_fp16 = self.model_fp16.to(torch.float32)
_ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
# Check this does not throw an error
_ = self.model_fp16.to("cpu")
# Check this does not throw an error
_ = self.model_fp16.half()
# Check this does not throw an error
_ = self.model_fp16.float()
class MixedInt8ModelClassesTest(BaseMixedInt8Test):
def setUp(self):