FIX / PEFT: Pass device correctly to peft (#30397)
pass device correctly to peft
This commit is contained in:
parent
13b3b90ab1
commit
367a0dbd53
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from ..utils import (
|
||||
check_peft_version,
|
||||
|
@ -25,6 +25,9 @@ from ..utils import (
|
|||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import dispatch_model
|
||||
from accelerate.utils import get_balanced_memory, infer_auto_device_map
|
||||
|
@ -32,10 +35,6 @@ if is_accelerate_available():
|
|||
# Minimum PEFT version supported for the integration
|
||||
MIN_PEFT_VERSION = "0.5.0"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
@ -151,6 +150,15 @@ class PeftAdapterMixin:
|
|||
"You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
|
||||
)
|
||||
|
||||
if "device" not in adapter_kwargs:
|
||||
device = self.device if not hasattr(self, "hf_device_map") else list(self.hf_device_map.values())[0]
|
||||
else:
|
||||
device = adapter_kwargs.pop("device")
|
||||
|
||||
# To avoid PEFT errors later on with safetensors.
|
||||
if isinstance(device, torch.device):
|
||||
device = str(device)
|
||||
|
||||
# We keep `revision` in the signature for backward compatibility
|
||||
if revision is not None and "revision" not in adapter_kwargs:
|
||||
adapter_kwargs["revision"] = revision
|
||||
|
@ -190,7 +198,7 @@ class PeftAdapterMixin:
|
|||
self._hf_peft_config_loaded = True
|
||||
|
||||
if peft_model_id is not None:
|
||||
adapter_state_dict = load_peft_weights(peft_model_id, token=token, **adapter_kwargs)
|
||||
adapter_state_dict = load_peft_weights(peft_model_id, token=token, device=device, **adapter_kwargs)
|
||||
|
||||
# We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
|
||||
processed_adapter_state_dict = {}
|
||||
|
|
Loading…
Reference in New Issue