diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 8bce35f9e3..78c012ec09 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -106,12 +106,13 @@ def no_init_weights(_enable=True): TODO(Patrick): Delete safety argument `_enable=True` at next major version. . """ global _init_weights + old_init_weights = _init_weights if _enable: _init_weights = False try: yield finally: - _init_weights = True + _init_weights = old_init_weights try: