Let's not cast them all (#18471)

* add correct dtypes when checking for params dtype

* forward contrib credits

* Update src/transformers/modeling_utils.py

Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>

* more comments

- added more comments on why we cast only floating point parameters

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: sgugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>
Younes Belkada 2022-08-08 23:48:49 +02:00 committed by GitHub
parent 499450ed75
commit ab62a23d8c
1 changed file with 4 additions and 2 deletions


@@ -543,8 +543,10 @@ def _load_state_dict_into_meta_model(
             param_name = param_name[len(start_prefix) :]
         module_name = param_name
-        # We convert floating dtypes to the `dtype` passed.
-        if dtype is not None and not str(param.dtype).startswith("torch.int"):
+        # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
+        # in int/uint/bool and not cast them.
+        if dtype is not None and torch.is_floating_point(param):
             param = param.to(dtype)
         if device_map is None:
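
For context, a minimal sketch of why only floating-point tensors are cast (this is not the actual `_load_state_dict_into_meta_model` implementation; the state dict, tensor names, and loop below are illustrative). Integer buffers such as position indices must keep their dtype, otherwise they can no longer be used for indexing, while floating-point weights can safely be converted to the requested `dtype`:

import torch

# Hypothetical state dict mixing a floating-point weight and an integer buffer.
state_dict = {
    "linear.weight": torch.randn(4, 4),                   # float32 weight
    "position_ids": torch.arange(16, dtype=torch.int64),  # int64 buffer, must stay integral
}

dtype = torch.float16  # target dtype requested by the caller

for name, param in state_dict.items():
    # Only floating-point tensors are cast to the requested dtype;
    # int/uint/bool tensors keep their original dtype.
    if dtype is not None and torch.is_floating_point(param):
        param = param.to(dtype)
    state_dict[name] = param

assert state_dict["linear.weight"].dtype == torch.float16
assert state_dict["position_ids"].dtype == torch.int64  # unchanged, still usable as indices

Compared to the previous string check on `str(param.dtype)`, `torch.is_floating_point(param)` also leaves uint and bool tensors untouched, not just the `torch.int*` dtypes.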