203 lines
7.0 KiB
Python
203 lines
7.0 KiB
Python
# coding=utf-8
# Copyright 2020 The Hugging Face Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
import io
|
|
import unittest
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
|
|
from transformers import AlbertForMaskedLM
|
|
from transformers.testing_utils import require_torch
|
|
from transformers.utils import ModelOutput, is_torch_available
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
|
|
|
|
|
|
@dataclass
class ModelOutputTest(ModelOutput):
    """Small `ModelOutput` dataclass used as a fixture by the tests in this file."""

    # The only field without a default value.
    a: float
    # Optional fields; each one defaults to None when not supplied.
    b: Optional[float] = None
    c: Optional[float] = None
|
|
|
|
|
|
class ModelOutputTester(unittest.TestCase):
    """Exercise attribute, index and dict-style access on `ModelOutputTest` instances."""

    def test_get_attributes(self):
        # Fields that were not passed keep their `None` default, while an
        # undeclared attribute name raises `AttributeError`.
        x = ModelOutputTest(a=30)
        self.assertEqual(x.a, 30)
        self.assertIsNone(x.b)
        self.assertIsNone(x.c)
        with self.assertRaises(AttributeError):
            _ = x.d

    def test_index_with_ints_and_slices(self):
        # Integer and slice indexing return the values of the fields that were set.
        x = ModelOutputTest(a=30, b=10)
        self.assertEqual(x[0], 30)
        self.assertEqual(x[1], 10)
        self.assertEqual(x[:2], (30, 10))
        self.assertEqual(x[:], (30, 10))

        # `b` is left unset here, so the value of `c` is found at position 1.
        x = ModelOutputTest(a=30, c=10)
        self.assertEqual(x[0], 30)
        self.assertEqual(x[1], 10)
        self.assertEqual(x[:2], (30, 10))
        self.assertEqual(x[:], (30, 10))

    def test_index_with_strings(self):
        # String keys resolve to set fields; a field that was left unset raises `KeyError`.
        x = ModelOutputTest(a=30, b=10)
        self.assertEqual(x["a"], 30)
        self.assertEqual(x["b"], 10)
        with self.assertRaises(KeyError):
            _ = x["c"]

        x = ModelOutputTest(a=30, c=10)
        self.assertEqual(x["a"], 30)
        self.assertEqual(x["c"], 10)
        with self.assertRaises(KeyError):
            _ = x["b"]

    def test_dict_like_properties(self):
        # `keys()`, `values()`, `items()` and iteration only expose the fields that were set.
        x = ModelOutputTest(a=30)
        self.assertEqual(list(x.keys()), ["a"])
        self.assertEqual(list(x.values()), [30])
        self.assertEqual(list(x.items()), [("a", 30)])
        self.assertEqual(list(x), ["a"])

        x = ModelOutputTest(a=30, b=10)
        self.assertEqual(list(x.keys()), ["a", "b"])
        self.assertEqual(list(x.values()), [30, 10])
        self.assertEqual(list(x.items()), [("a", 30), ("b", 10)])
        self.assertEqual(list(x), ["a", "b"])

        x = ModelOutputTest(a=30, c=10)
        self.assertEqual(list(x.keys()), ["a", "c"])
        self.assertEqual(list(x.values()), [30, 10])
        self.assertEqual(list(x.items()), [("a", 30), ("c", 10)])
        self.assertEqual(list(x), ["a", "c"])

        # The mutating dict methods below are expected to be rejected. The concrete
        # exception type is not visible from this file, so the broad `Exception`
        # check is kept as-is.
        with self.assertRaises(Exception):
            x.update({"d": 20})
        with self.assertRaises(Exception):
            del x["a"]
        with self.assertRaises(Exception):
            _ = x.pop("a")
        with self.assertRaises(Exception):
            _ = x.setdefault("d", 32)

    def test_set_attributes(self):
        # Assigning through an attribute is visible through key access as well.
        x = ModelOutputTest(a=30)
        x.a = 10
        self.assertEqual(x.a, 10)
        self.assertEqual(x["a"], 10)

    def test_set_keys(self):
        # Assigning through a key is visible through attribute access as well.
        x = ModelOutputTest(a=30)
        x["a"] = 10
        self.assertEqual(x.a, 10)
        self.assertEqual(x["a"], 10)

    def test_instantiate_from_dict(self):
        # A single dict argument populates the fields by name.
        x = ModelOutputTest({"a": 30, "b": 10})
        self.assertEqual(list(x.keys()), ["a", "b"])
        self.assertEqual(x.a, 30)
        self.assertEqual(x.b, 10)

    def test_instantiate_from_iterator(self):
        # A single list of (key, value) pairs populates the fields by name.
        x = ModelOutputTest([("a", 30), ("b", 10)])
        self.assertEqual(list(x.keys()), ["a", "b"])
        self.assertEqual(x.a, 30)
        self.assertEqual(x.b, 10)

        # A pair whose key is not a string is rejected with `ValueError`.
        with self.assertRaises(ValueError):
            _ = ModelOutputTest([("a", 30), (10, 10)])

        # A tuple passed by keyword is stored as the value of `a`, not unpacked into pairs.
        x = ModelOutputTest(a=(30, 30))
        self.assertEqual(list(x.keys()), ["a"])
        self.assertEqual(x.a, (30, 30))

    @require_torch
    def test_torch_pytree(self):
        # ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves)
        # this is important for DistributedDataParallel gradient synchronization with static_graph=True
        import torch.utils._pytree as pytree

        x = ModelOutput({"a": 1.0, "c": 2.0})
        self.assertFalse(pytree._is_leaf(x))

        x = ModelOutputTest(a=1.0, c=2.0)
        self.assertFalse(pytree._is_leaf(x))

        expected_flat_outs = [1.0, 2.0]
        expected_tree_spec = pytree.TreeSpec(ModelOutputTest, ["a", "c"], [pytree.LeafSpec(), pytree.LeafSpec()])

        actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x)
        self.assertEqual(expected_flat_outs, actual_flat_outs)
        self.assertEqual(expected_tree_spec, actual_tree_spec)

        # Flatten followed by unflatten must give back an equal object.
        unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
        self.assertEqual(x, unflattened_x)

        if is_torch_greater_or_equal_than_2_2:
            # NOTE(review): the expected string hard-codes the module path
            # `tests.utils.test_model_output`, so it presumably depends on how the
            # test module is imported - confirm before moving this file.
            self.assertEqual(
                pytree.treespec_dumps(actual_tree_spec),
                '[1, {"type": "tests.utils.test_model_output.ModelOutputTest", "context": "[\\"a\\", \\"c\\"]", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]',
            )

    # TODO: @ydshieh
    @unittest.skip("CPU OOM")
    @require_torch
    def test_export_serialization(self):
        # Export a model, round-trip it through `torch.export.save` / `torch.export.load`
        # using an in-memory buffer, and compare its logits with the eager model's.
        if not is_torch_greater_or_equal_than_2_2:
            # Report the test as skipped instead of silently passing.
            self.skipTest("test_export_serialization needs torch >= 2.2")

        model_cls = AlbertForMaskedLM
        model_config = model_cls.config_class()
        model = model_cls(model_config)

        input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}

        ep = torch.export.export(model, (), input_dict)

        buffer = io.BytesIO()
        torch.export.save(ep, buffer)
        buffer.seek(0)
        loaded_ep = torch.export.load(buffer)

        # Fresh random inputs, distinct from the ones used to trace the export.
        input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}
        # Use a unittest assertion: a bare `assert` is stripped when Python runs with -O.
        self.assertTrue(torch.allclose(model(**input_dict).logits, loaded_ep(**input_dict).logits))
|
|
|
|
|
|
class ModelOutputTestNoDataclass(ModelOutput):
    """Deliberately invalid `ModelOutput` subclass: it declares fields but omits the `@dataclass` decorator."""

    # Same field layout as a decorated subclass would declare; only the decorator is missing.
    a: float
    b: Optional[float] = None
    c: Optional[float] = None
|
|
|
|
|
|
class ModelOutputSubclassTester(unittest.TestCase):
    """Checks on how `ModelOutput` and an undecorated subclass of it can be instantiated."""

    def test_direct_model_output(self):
        # Building a plain `ModelOutput` from a dict must not raise.
        _ = ModelOutput({"a": 1.1})

    def test_subclass_no_dataclass(self):
        # Instantiating a `ModelOutput` subclass that lacks `@dataclass` must raise `TypeError`.
        # Valid (decorated) subclasses are already exercised by the other tests in this file.
        self.assertRaises(TypeError, ModelOutputTestNoDataclass, a=1.1, b=2.2, c=3.3)
|