Make `ModelOutput` serializable (#26493)

* Make `ModelOutput` serializable

Original PR from diffusers : https://github.com/huggingface/diffusers/pull/5234

* Black
This commit is contained in:
Charles Bensimon 2023-10-05 11:08:44 +02:00 committed by GitHub
parent 54e17a15dc
commit 19f0b7dd02
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 0 deletions

View File

@ -416,6 +416,13 @@ class ModelOutput(OrderedDict):
# Don't call self.__setattr__ to avoid recursion errors
super().__setattr__(key, value)
def __reduce__(self):
if not is_dataclass(self):
return super().__reduce__()
callable, _args, *remaining = super().__reduce__()
args = tuple(getattr(self, field.name) for field in fields(self))
return callable, args, *remaining
def to_tuple(self) -> Tuple[Any]:
"""
Convert self to a tuple containing all the attributes/keys that are not `None`.