Add balanced strategies for device_map in from_pretrained (#18349)

* Add balanced strategies for device_map in from_pretrained

* Add safeguards for Accelerate version

* Update src/transformers/modeling_utils.py

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>

* Style

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
Sylvain Gugger 2022-08-01 10:28:26 -04:00 committed by GitHub
parent 39e76d76fd
commit e0bc4c73e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 26 additions and 3 deletions

View File

@ -76,6 +76,7 @@ from .utils.versions import require_version_core
if is_accelerate_available():
from accelerate import __version__ as accelerate_version
from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
from accelerate.utils import (
load_offloaded_weights,
@ -84,6 +85,11 @@ if is_accelerate_available():
set_module_tensor_to_device,
)
if version.parse(accelerate_version) > version.parse("0.11.0"):
from accelerate.utils import get_balanced_memory
else:
get_balanced_memory = None
logger = logging.get_logger(__name__)
@ -1697,7 +1703,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
same device.
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`.
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/big_modeling#designing-a-device-map).
max_memory (`Dict`, *optional*):
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
GPU and the available CPU RAM if unset.
@ -2105,10 +2113,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
with ContextManagers(init_contexts):
model = cls(config, *model_args, **model_kwargs)
if device_map == "auto":
if isinstance(device_map, str):
if model._no_split_modules is None:
raise ValueError(f"{model.__class__.__name__} does not support `device_map='auto'` yet.")
raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.")
no_split_modules = model._no_split_modules
if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
raise ValueError(
"If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
"'sequential'."
)
elif device_map in ["balanced", "balanced_low_0"] and get_balanced_memory is None:
raise ValueError(f"`device_map={device_map}` requires a source install of Accelerate.")
if device_map != "sequential" and get_balanced_memory is not None:
max_memory = get_balanced_memory(
model,
max_memory=max_memory,
no_split_module_classes=no_split_modules,
dtype=torch_dtype,
low_zero=(device_map == "balanced_low_0"),
)
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(