fix `AutoModel.from_pretrained(..., torch_dtype=...)` (#13209)

* fix AutoModel.from_pretrained(..., torch_dtype=...)

* fix to_diff_dict

* add better test

* torch is not always available when a model has self.torch_dtype
Stas Bekman 2021-08-24 02:43:41 -07:00 committed by GitHub
parent 39db2f3c19
commit 5c6eca71a9
2 changed files with 46 additions and 4 deletions
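
The behaviour being fixed is easiest to see from the new test assertions further down. Roughly, a minimal usage sketch of what should work after the fix (the "t5-small" checkpoint name is a stand-in for the test's tiny fp32 fixture, not something the commit itself uses):

    import torch
    from transformers import AutoModel

    # an explicit dtype request must now be honoured on the AutoModel path too
    model = AutoModel.from_pretrained("t5-small", torch_dtype=torch.float16)
    assert model.dtype == torch.float16

    # torch_dtype="auto" instead derives the dtype from the saved weights
    model = AutoModel.from_pretrained("t5-small", torch_dtype="auto")
    assert model.dtype == torch.float32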

src/transformers/configuration_utils.py

@@ -30,6 +30,7 @@ from .file_utils import (
     hf_bucket_url,
     is_offline_mode,
     is_remote_url,
+    is_torch_available,
 )
 from .utils import logging
@@ -207,6 +208,9 @@ class PretrainedConfig(PushToHubMixin):
           this attribute contains just the floating type string without the ``torch.`` prefix. For example, for
           ``torch.float16`` ``torch_dtype`` is the ``"float16"`` string.
+
+          This attribute is currently not being used during model loading time, but this may change in the future
+          versions. But we can already start preparing for the future by saving the dtype with save_pretrained.
 
     TensorFlow specific parameters
         - **use_bfloat16** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should use
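
The docstring addition above describes the on-disk format: the dtype is stored as a bare string. A short sketch of the round trip this implies, mirroring the config.json assertions in the new test (the checkpoint name and temporary path are placeholders):

    import json
    import torch
    from transformers import T5ForConditionalGeneration

    model = T5ForConditionalGeneration.from_pretrained("t5-small").half()
    model.save_pretrained("/tmp/t5-fp16")

    # config.json stores the dtype as a bare string, without the "torch." prefix
    with open("/tmp/t5-fp16/config.json") as f:
        assert json.load(f)["torch_dtype"] == "float16"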
@@ -270,6 +274,14 @@
         else:
             self.num_labels = kwargs.pop("num_labels", 2)
 
+        if self.torch_dtype is not None and isinstance(self.torch_dtype, str):
+            # we will start using self.torch_dtype in v5, but to be consistent with
+            # from_pretrained's torch_dtype arg convert it to an actual torch.dtype object
+            if is_torch_available():
+                import torch
+
+                self.torch_dtype = getattr(torch, self.torch_dtype)
+
         # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
         self.tokenizer_class = kwargs.pop("tokenizer_class", None)
         self.prefix = kwargs.pop("prefix", None)
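
Taken out of the class, the new constructor logic boils down to the following guarded conversion (a sketch; the real code uses the library's is_torch_available() helper, approximated here with importlib so the snippet stands alone):

    import importlib.util

    torch_dtype = "float16"  # the string form read from config.json
    if torch_dtype is not None and isinstance(torch_dtype, str):
        if importlib.util.find_spec("torch") is not None:
            import torch

            # "float16" -> torch.float16; skipped entirely when torch is absent
            torch_dtype = getattr(torch, torch_dtype)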
@@ -574,7 +586,8 @@
         for key, value in kwargs.items():
             if hasattr(config, key):
                 setattr(config, key, value)
-                to_remove.append(key)
+                if key != "torch_dtype":
+                    to_remove.append(key)
 
         for key in to_remove:
             kwargs.pop(key, None)
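
Keeping "torch_dtype" out of to_remove matters because the auto classes resolve the config first with return_unused_kwargs=True and then forward the leftover kwargs to the model's from_pretrained; if the key were consumed here it would never reach the model-loading code. A hedged sketch of that hand-off (network access and the "t5-small" checkpoint are assumptions of the example, not of the commit):

    import torch
    from transformers import AutoConfig

    config, unused = AutoConfig.from_pretrained(
        "t5-small", return_unused_kwargs=True, torch_dtype=torch.float16
    )
    # after this fix the dtype survives config resolution and is handed back,
    # so AutoModel.from_pretrained can still act on it
    assert unused["torch_dtype"] == torch.float16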
@@ -640,6 +653,8 @@
             ):
                 serializable_config_dict[key] = value
 
+        self.dict_torch_dtype_to_str(serializable_config_dict)
+
         return serializable_config_dict
 
     def to_dict(self) -> Dict[str, Any]:
@@ -656,6 +671,8 @@
         # Transformers version when serializing the model
         output["transformers_version"] = __version__
 
+        self.dict_torch_dtype_to_str(output)
+
         return output
 
     def to_json_string(self, use_diff: bool = True) -> str:
@@ -738,6 +755,15 @@
             setattr(self, k, v)
 
+    def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
+        """
+        Checks whether the passed dictionary has a `torch_dtype` key and if it's not None, converts torch.dtype to a
+        string of just the type. For example, :obj:`torch.float32` gets converted into the `"float32"` string, which
+        can then be stored in the json format.
+        """
+        if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
+            d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
+
 
 PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
 PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(

tests/test_modeling_common.py

@@ -16,6 +16,7 @@
 import copy
 import gc
 import inspect
+import json
 import os.path
 import random
 import tempfile
@@ -1663,9 +1664,11 @@ class ModelUtilsTest(TestCasePlus):
     @require_torch
     def test_model_from_pretrained_torch_dtype(self):
         # test that the model can be instantiated with dtype of either
-        # 1. config.torch_dtype setting in the saved model (priority)
-        # 2. via autodiscovery by looking at model weights
+        # 1. explicit from_pretrained's torch_dtype argument
+        # 2. via autodiscovery by looking at model weights (torch_dtype="auto")
         # so if a model.half() was saved, we want it to be instantiated as such.
+        #
+        # test an explicit model class, but also AutoModel separately as the latter goes through a different code path
         model_path = self.get_auto_remove_tmp_dir()
 
         # baseline - we know TINY_T5 is fp32 model
@@ -1688,13 +1691,26 @@
         model = model.half()
         model.save_pretrained(model_path)
         model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
-        self.assertEqual(model.config.torch_dtype, "float16")  # tests `config.torch_dtype` saving
+        self.assertEqual(model.config.torch_dtype, torch.float16)
         self.assertEqual(model.dtype, torch.float16)
 
+        # tests `config.torch_dtype` saving
+        with open(f"{model_path}/config.json") as f:
+            config_dict = json.load(f)
+        self.assertEqual(config_dict["torch_dtype"], "float16")
+
         # test fp16 save_pretrained, loaded with the explicit fp16
         model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
         self.assertEqual(model.dtype, torch.float16)
 
+        # test AutoModel separately as it goes through a different path
+        # test auto-detection
+        model = AutoModel.from_pretrained(TINY_T5, torch_dtype="auto")
+        self.assertEqual(model.dtype, torch.float32)
+
+        # test forcing an explicit dtype
+        model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
+        self.assertEqual(model.dtype, torch.float16)
+
     @require_torch
     @is_staging_test