Let's not cast them all (#18471)
* add correct dtypes when checking for params dtype

* forward contrib credits

* Update src/transformers/modeling_utils.py

Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>

* more comments - added more comments on why we cast only floating point parameters

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: sgugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>
This commit is contained in:
parent
499450ed75
commit
ab62a23d8c
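The fix hinges on how a parameter's dtype is tested before casting. The old check, `not str(param.dtype).startswith("torch.int")`, only excludes the signed-integer dtypes, so bool and uint8 tensors slip through and get cast to the floating `dtype`. A minimal standalone illustration (not part of the PR) comparing the two checks:

import torch

# Compare the old string-based check with torch.is_floating_point for a few dtypes.
# Only the float32 tensor should be cast; int64, bool, and uint8 must be left alone.
for t in (
    torch.tensor([1.0]),                   # torch.float32
    torch.tensor([1], dtype=torch.int64),  # torch.int64
    torch.tensor([True]),                  # torch.bool
    torch.tensor([1], dtype=torch.uint8),  # torch.uint8
):
    old_check = not str(t.dtype).startswith("torch.int")  # wrongly True for bool/uint8
    new_check = torch.is_floating_point(t)                # True only for floating dtypes
    print(f"{t.dtype}: old={old_check} new={new_check}")

Running this prints old=True for torch.bool and torch.uint8 while new=False, which is exactly the over-casting the commit removes.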
@@ -543,8 +543,10 @@ def _load_state_dict_into_meta_model(
         param_name = param_name[len(start_prefix) :]

         module_name = param_name
-        # We convert floating dtypes to the `dtype` passed.
-        if dtype is not None and not str(param.dtype).startswith("torch.int"):
+
+        # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
+        # in int/uint/bool and not cast them.
+        if dtype is not None and torch.is_floating_point(param):
             param = param.to(dtype)

         if device_map is None:
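Read outside the diff, the casting step reduces to the sketch below. This is a hypothetical simplification with invented names (`cast_floating_params` is not in the PR); the real `_load_state_dict_into_meta_model` additionally handles device maps and meta tensors.

import torch

# Hypothetical reduction of the casting step: only floating-point tensors are
# converted to the requested dtype; int/uint/bool buffers keep their dtype.
def cast_floating_params(state_dict, dtype=None):
    for name, param in state_dict.items():
        if dtype is not None and torch.is_floating_point(param):
            state_dict[name] = param.to(dtype)
    return state_dict

sd = {
    "weight": torch.randn(2, 2),             # float32 -> cast to float16
    "num_batches_tracked": torch.tensor(0),  # int64   -> left unchanged
}
sd = cast_floating_params(sd, dtype=torch.float16)
print(sd["weight"].dtype, sd["num_batches_tracked"].dtype)  # torch.float16 torch.int64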