FIX / PEFT: Pass device correctly to peft (#30397)

pass device correctly to peft
This commit is contained in:
Younes Belkada 2024-04-22 18:13:19 +02:00 committed by GitHub
parent 13b3b90ab1
commit 367a0dbd53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 14 additions and 6 deletions

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
import warnings import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from ..utils import ( from ..utils import (
check_peft_version, check_peft_version,
@ -25,6 +25,9 @@ from ..utils import (
) )
if is_torch_available():
import torch
if is_accelerate_available(): if is_accelerate_available():
from accelerate import dispatch_model from accelerate import dispatch_model
from accelerate.utils import get_balanced_memory, infer_auto_device_map from accelerate.utils import get_balanced_memory, infer_auto_device_map
@ -32,10 +35,6 @@ if is_accelerate_available():
# Minimum PEFT version supported for the integration # Minimum PEFT version supported for the integration
MIN_PEFT_VERSION = "0.5.0" MIN_PEFT_VERSION = "0.5.0"
if TYPE_CHECKING:
if is_torch_available():
import torch
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -151,6 +150,15 @@ class PeftAdapterMixin:
"You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter." "You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
) )
if "device" not in adapter_kwargs:
device = self.device if not hasattr(self, "hf_device_map") else list(self.hf_device_map.values())[0]
else:
device = adapter_kwargs.pop("device")
# To avoid PEFT errors later on with safetensors.
if isinstance(device, torch.device):
device = str(device)
# We keep `revision` in the signature for backward compatibility # We keep `revision` in the signature for backward compatibility
if revision is not None and "revision" not in adapter_kwargs: if revision is not None and "revision" not in adapter_kwargs:
adapter_kwargs["revision"] = revision adapter_kwargs["revision"] = revision
@ -190,7 +198,7 @@ class PeftAdapterMixin:
self._hf_peft_config_loaded = True self._hf_peft_config_loaded = True
if peft_model_id is not None: if peft_model_id is not None:
adapter_state_dict = load_peft_weights(peft_model_id, token=token, **adapter_kwargs) adapter_state_dict = load_peft_weights(peft_model_id, token=token, device=device, **adapter_kwargs)
# We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility # We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
processed_adapter_state_dict = {} processed_adapter_state_dict = {}