[BNB] Throw `ValueError` when trying to cast or assign (#20409)
* `bnb` ValueError when tries to cast or assign * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * remove docstrings * change error log Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
03ae1f060b
commit
ad654e4484
|
@ -1666,6 +1666,36 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
mem = mem + mem_bufs
|
||||
return mem
|
||||
|
||||
def to(self, *args, **kwargs):
    """Override of `nn.Module.to` that refuses device/dtype moves for 8-bit models.

    Models loaded with bitsandbytes int8 quantization are already dispatched to the
    right devices and quantized in place; re-casting or moving them would corrupt
    the quantized weights, so this raises instead of delegating.

    Raises:
        ValueError: if the model was loaded in 8-bit (``is_loaded_in_8bit``).
    """
    # Guard clause: bail out early for 8-bit models, otherwise defer to nn.Module.
    if getattr(self, "is_loaded_in_8bit", False):
        raise ValueError(
            "`.to` is not supported for `8-bit` models. Please use the model as it is, since the"
            " model has already been set to the correct devices and casted to the correct `dtype`."
        )
    return super().to(*args, **kwargs)
|
||||
|
||||
def half(self, *args):
    """Override of `nn.Module.half` that refuses fp16 casting for 8-bit models.

    An int8-quantized (bitsandbytes) model already has its final dtype; casting it
    to half precision would break the quantization, so raise instead.

    Raises:
        ValueError: if the model was loaded in 8-bit (``is_loaded_in_8bit``).
    """
    # Guard clause: refuse for 8-bit models, otherwise defer to nn.Module.
    if getattr(self, "is_loaded_in_8bit", False):
        raise ValueError(
            "`.half()` is not supported for `8-bit` models. Please use the model as it is, since the"
            " model has already been casted to the correct `dtype`."
        )
    return super().half(*args)
|
||||
|
||||
def float(self, *args):
    """Override of `nn.Module.float` that refuses fp32 casting for 8-bit models.

    An int8-quantized (bitsandbytes) model already has its final dtype; casting it
    to full precision would break the quantization, so raise instead.

    Raises:
        ValueError: if the model was loaded in 8-bit (``is_loaded_in_8bit``).
    """
    # Guard clause: refuse for 8-bit models, otherwise defer to nn.Module.
    if getattr(self, "is_loaded_in_8bit", False):
        raise ValueError(
            "`.float()` is not supported for `8-bit` models. Please use the model as it is, since the"
            " model has already been casted to the correct `dtype`."
        )
    return super().float(*args)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
|
||||
r"""
|
||||
|
|
|
@ -115,6 +115,46 @@ class MixedInt8Test(BaseMixedInt8Test):
|
|||
with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.model_8bit.save_pretrained(tmpdirname)
|
||||
|
||||
def test_device_and_dtype_assignment(self):
    r"""
    Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
    Checks also if other models are casted correctly.
    """
    with self.assertRaises(ValueError):
        # Tries with a `str` device
        self.model_8bit.to("cpu")

    with self.assertRaises(ValueError):
        # Tries with a `dtype`
        self.model_8bit.to(torch.float16)

    with self.assertRaises(ValueError):
        # Tries with a `torch.device`
        self.model_8bit.to(torch.device("cuda:0"))

    with self.assertRaises(ValueError):
        # Tries with a `.float()` cast
        self.model_8bit.float()

    with self.assertRaises(ValueError):
        # Tries with a `.half()` cast
        self.model_8bit.half()

    # Test if we did not break anything
    encoded_input = self.tokenizer(self.input_text, return_tensors="pt")

    self.model_fp16 = self.model_fp16.to(torch.float32)
    _ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)

    # Check this does not throw an error
    _ = self.model_fp16.to("cpu")

    # Check this does not throw an error
    _ = self.model_fp16.half()

    # Check this does not throw an error
    _ = self.model_fp16.float()
|
||||
|
||||
|
||||
class MixedInt8ModelClassesTest(BaseMixedInt8Test):
|
||||
def setUp(self):
|
||||
|
|
Loading…
Reference in New Issue