FIX / PEFT: Pass device correctly to peft (#30397)
pass device correctly to peft
This commit is contained in:
parent
13b3b90ab1
commit
367a0dbd53
|
@ -13,7 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import inspect
|
import inspect
|
||||||
import warnings
|
import warnings
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
check_peft_version,
|
check_peft_version,
|
||||||
|
@ -25,6 +25,9 @@ from ..utils import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
from accelerate import dispatch_model
|
from accelerate import dispatch_model
|
||||||
from accelerate.utils import get_balanced_memory, infer_auto_device_map
|
from accelerate.utils import get_balanced_memory, infer_auto_device_map
|
||||||
|
@ -32,10 +35,6 @@ if is_accelerate_available():
|
||||||
# Minimum PEFT version supported for the integration
|
# Minimum PEFT version supported for the integration
|
||||||
MIN_PEFT_VERSION = "0.5.0"
|
MIN_PEFT_VERSION = "0.5.0"
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
if is_torch_available():
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
@ -151,6 +150,15 @@ class PeftAdapterMixin:
|
||||||
"You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
|
"You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "device" not in adapter_kwargs:
|
||||||
|
device = self.device if not hasattr(self, "hf_device_map") else list(self.hf_device_map.values())[0]
|
||||||
|
else:
|
||||||
|
device = adapter_kwargs.pop("device")
|
||||||
|
|
||||||
|
# To avoid PEFT errors later on with safetensors.
|
||||||
|
if isinstance(device, torch.device):
|
||||||
|
device = str(device)
|
||||||
|
|
||||||
# We keep `revision` in the signature for backward compatibility
|
# We keep `revision` in the signature for backward compatibility
|
||||||
if revision is not None and "revision" not in adapter_kwargs:
|
if revision is not None and "revision" not in adapter_kwargs:
|
||||||
adapter_kwargs["revision"] = revision
|
adapter_kwargs["revision"] = revision
|
||||||
|
@ -190,7 +198,7 @@ class PeftAdapterMixin:
|
||||||
self._hf_peft_config_loaded = True
|
self._hf_peft_config_loaded = True
|
||||||
|
|
||||||
if peft_model_id is not None:
|
if peft_model_id is not None:
|
||||||
adapter_state_dict = load_peft_weights(peft_model_id, token=token, **adapter_kwargs)
|
adapter_state_dict = load_peft_weights(peft_model_id, token=token, device=device, **adapter_kwargs)
|
||||||
|
|
||||||
# We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
|
# We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
|
||||||
processed_adapter_state_dict = {}
|
processed_adapter_state_dict = {}
|
||||||
|
|
Loading…
Reference in New Issue