fix `AutoModel.from_pretrained(..., torch_dtype=...)` (#13209)
* fix AutoModel.from_pretrained(..., torch_dtype=...)
* fix to_diff_dict
* add better test
* torch is not always available when a model has self.torch_dtype
Parent: 39db2f3c19
Commit: 5c6eca71a9
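For context, the call pattern this commit targets looks roughly like the sketch below; `t5-small` is only a placeholder checkpoint name, and the dtype-selection behavior is the one exercised by the test changes further down.

```python
import torch
from transformers import AutoModel, T5ForConditionalGeneration

# Force an explicit dtype: weights are loaded directly as fp16.
model = AutoModel.from_pretrained("t5-small", torch_dtype=torch.float16)

# "auto": pick the dtype up from the saved checkpoint instead of forcing one.
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype="auto")
```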
@@ -30,6 +30,7 @@ from .file_utils import (
     hf_bucket_url,
     is_offline_mode,
     is_remote_url,
+    is_torch_available,
 )
 from .utils import logging
 
@@ -207,6 +208,9 @@ class PretrainedConfig(PushToHubMixin):
             this attribute contains just the floating type string without the ``torch.`` prefix. For example, for
             ``torch.float16`` ``torch_dtype`` is the ``"float16"`` string.
 
+            This attribute is currently not being used during model loading time, but this may change in the future
+            versions. But we can already start preparing for the future by saving the dtype with save_pretrained.
+
         TensorFlow specific parameters
 
         - **use_bfloat16** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should use
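A minimal sketch (not part of the diff, assuming torch is installed and a transformers version that already contains this commit) of the round trip the docstring describes: the attribute is a real `torch.dtype` at runtime but serializes as the bare string.

```python
import torch
from transformers import PretrainedConfig

config = PretrainedConfig(torch_dtype="float16")     # string form, as read from config.json
assert config.torch_dtype is torch.float16           # converted in __init__ when torch is available
assert config.to_dict()["torch_dtype"] == "float16"  # converted back to a plain string for JSON
```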
@@ -270,6 +274,14 @@ class PretrainedConfig(PushToHubMixin):
         else:
             self.num_labels = kwargs.pop("num_labels", 2)
 
+        if self.torch_dtype is not None and isinstance(self.torch_dtype, str):
+            # we will start using self.torch_dtype in v5, but to be consistent with
+            # from_pretrained's torch_dtype arg convert it to an actual torch.dtype object
+            if is_torch_available():
+                import torch
+
+                self.torch_dtype = getattr(torch, self.torch_dtype)
+
         # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
         self.tokenizer_class = kwargs.pop("tokenizer_class", None)
         self.prefix = kwargs.pop("prefix", None)
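As a standalone illustration (outside the class, purely a sketch) of what the added `__init__` block does: the string-to-dtype conversion is just an attribute lookup on the `torch` module.

```python
import torch

torch_dtype = "float16"                    # the form stored in config.json
torch_dtype = getattr(torch, torch_dtype)  # -> torch.float16
assert torch_dtype is torch.float16
```

The `is_torch_available()` guard exists because a config can be loaded in an environment without PyTorch (the commit message's "torch is not always available" point); in that case the attribute simply stays a string.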
@@ -574,7 +586,8 @@ class PretrainedConfig(PushToHubMixin):
         for key, value in kwargs.items():
             if hasattr(config, key):
                 setattr(config, key, value)
-                to_remove.append(key)
+                if key != "torch_dtype":
+                    to_remove.append(key)
         for key in to_remove:
             kwargs.pop(key, None)
 
@@ -640,6 +653,8 @@ class PretrainedConfig(PushToHubMixin):
             ):
                 serializable_config_dict[key] = value
 
+        self.dict_torch_dtype_to_str(serializable_config_dict)
+
         return serializable_config_dict
 
     def to_dict(self) -> Dict[str, Any]:
@@ -656,6 +671,8 @@ class PretrainedConfig(PushToHubMixin):
         # Transformers version when serializing the model
         output["transformers_version"] = __version__
 
+        self.dict_torch_dtype_to_str(output)
+
         return output
 
     def to_json_string(self, use_diff: bool = True) -> str:
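Both `to_diff_dict` and `to_dict` now route through the new helper because a `torch.dtype` object is not JSON-serializable, so leaving `torch.float16` in the dict would break `to_json_string`. A quick sketch of the failure mode and the conversion:

```python
import json
import torch

d = {"torch_dtype": torch.float16}
try:
    json.dumps(d)  # raises TypeError: a torch.dtype is not JSON serializable
except TypeError as err:
    print("cannot serialize:", err)

d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]  # "float16"
print(json.dumps(d))  # {"torch_dtype": "float16"}
```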
@@ -738,6 +755,15 @@ class PretrainedConfig(PushToHubMixin):
 
             setattr(self, k, v)
 
+    def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
+        """
+        Checks whether the passed dictionary has a `torch_dtype` key and if it's not None, converts torch.dtype to a
+        string of just the type. For example, :obj:`torch.float32` get converted into `"float32"` string, which can
+        then be stored in the json format.
+        """
+        if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
+            d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
+
 
 PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
 PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
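A worked example of the string manipulation inside `dict_torch_dtype_to_str`, relying only on how `torch.dtype` objects print:

```python
import torch

print(str(torch.float32))                # "torch.float32"
print(str(torch.float32).split(".")[1])  # "float32"
```

Values that are already strings (or missing or None) are left untouched by the helper. The remaining hunks come from the test module and exercise this behavior end to end.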
@@ -16,6 +16,7 @@
 import copy
 import gc
 import inspect
+import json
 import os.path
 import random
 import tempfile
@@ -1663,9 +1664,11 @@ class ModelUtilsTest(TestCasePlus):
     @require_torch
     def test_model_from_pretrained_torch_dtype(self):
         # test that the model can be instantiated with dtype of either
-        # 1. config.torch_dtype setting in the saved model (priority)
-        # 2. via autodiscovery by looking at model weights
+        # 1. explicit from_pretrained's torch_dtype argument
+        # 2. via autodiscovery by looking at model weights (torch_dtype="auto")
         # so if a model.half() was saved, we want it to be instantiated as such.
+        #
+        # test an explicit model class, but also AutoModel separately as the latter goes through a different code path
         model_path = self.get_auto_remove_tmp_dir()
 
         # baseline - we know TINY_T5 is fp32 model
@@ -1688,13 +1691,26 @@ class ModelUtilsTest(TestCasePlus):
         model = model.half()
         model.save_pretrained(model_path)
         model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
-        self.assertEqual(model.config.torch_dtype, "float16")  # tests `config.torch_dtype` saving
+        self.assertEqual(model.config.torch_dtype, torch.float16)
         self.assertEqual(model.dtype, torch.float16)
 
+        # tests `config.torch_dtype` saving
+        with open(f"{model_path}/config.json") as f:
+            config_dict = json.load(f)
+        self.assertEqual(config_dict["torch_dtype"], "float16")
+
         # test fp16 save_pretrained, loaded with the explicit fp16
         model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
         self.assertEqual(model.dtype, torch.float16)
 
+        # test AutoModel separately as it goes through a different path
+        # test auto-detection
+        model = AutoModel.from_pretrained(TINY_T5, torch_dtype="auto")
+        self.assertEqual(model.dtype, torch.float32)
+        # test forcing an explicit dtype
+        model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
+        self.assertEqual(model.dtype, torch.float16)
+
 
 @require_torch
 @is_staging_test
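Taken together, the behavior the new assertions lock in looks roughly like this at the user level (a sketch; `t5-small` stands in for any PyTorch checkpoint and `./t5-fp16` is a scratch directory):

```python
import json
import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small").half()
model.save_pretrained("./t5-fp16")

# the saved config records the dtype as a plain string
with open("./t5-fp16/config.json") as f:
    assert json.load(f)["torch_dtype"] == "float16"

# reloading with torch_dtype="auto" restores the saved precision
reloaded = T5ForConditionalGeneration.from_pretrained("./t5-fp16", torch_dtype="auto")
assert reloaded.dtype == torch.float16
assert reloaded.config.torch_dtype == torch.float16
```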