[PEFT] Allow PEFT model dict to be loaded (#25721)

* Allow PEFT model dict to be loaded

* make style

* make style

* Apply suggestions from code review

* address comments

* fixup

* final change

* added tests

* fix test

* better logic for handling if adapter has been loaded

* Update tests/peft_integration/test_peft_integration.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Patrick von Platen 2023-09-15 18:22:01 +02:00 committed by GitHub
parent 8b13471494
commit 0a55d9f737
2 changed files with 76 additions and 22 deletions
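
In short, this change lets `load_adapter` consume an in-memory PEFT config and state dict instead of resolving everything from a `peft_model_id` on the Hub. A minimal sketch of the new call pattern, assuming `peft` >= 0.5.0 is installed; the model and adapter repo ids below are illustrative placeholders, not taken from this commit:

import torch
from huggingface_hub import hf_hub_download
from peft import LoraConfig
from transformers import AutoModelForCausalLM

# Placeholder ids: any causal LM with a compatible LoRA state dict works.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
state_dict_path = hf_hub_download("some-user/opt-350m-lora", "adapter_model.bin")
adapter_state_dict = torch.load(state_dict_path)

# Target modules are inferred for OPT by peft when not set explicitly.
peft_config = LoraConfig(init_lora_weights=False)

# New in this commit: `peft_model_id` may be omitted when `peft_config`
# and `adapter_state_dict` are passed directly.
model.load_adapter(adapter_state_dict=adapter_state_dict, peft_config=peft_config)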

src/transformers/integrations/peft.py

@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
-from typing import Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional

 from ..utils import (
     check_peft_version,
     find_adapter_config_file,
     is_accelerate_available,
     is_peft_available,
+    is_torch_available,
     logging,
 )
@@ -30,6 +31,11 @@ if is_accelerate_available():
 # Minimum PEFT version supported for the integration
 MIN_PEFT_VERSION = "0.5.0"

+if TYPE_CHECKING:
+    if is_torch_available():
+        import torch
+
+
 logger = logging.get_logger(__name__)
@@ -61,7 +67,7 @@ class PeftAdapterMixin:
     def load_adapter(
         self,
-        peft_model_id: str,
+        peft_model_id: Optional[str] = None,
         adapter_name: Optional[str] = None,
         revision: Optional[str] = None,
         token: Optional[str] = None,
         device_map: Optional[str] = "auto",
@@ -69,6 +75,8 @@ class PeftAdapterMixin:
         max_memory: Optional[str] = None,
         offload_folder: Optional[str] = None,
         offload_index: Optional[int] = None,
+        peft_config: Dict[str, Any] = None,
+        adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
     ) -> None:
         """
         Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we
@@ -77,7 +85,7 @@ class PeftAdapterMixin:
         Requires peft as a backend to load the adapter weights.

         Args:
-            peft_model_id (`str`):
+            peft_model_id (`str`, *optional*):
                 The identifier of the model to look for on the Hub, or a local path to the saved adapter config file
                 and adapter weights.
             adapter_name (`str`, *optional*):
@@ -114,6 +122,12 @@ class PeftAdapterMixin:
                 If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
             offload_index (`int`, `optional`):
                 `offload_index` argument to be passed to `accelerate.dispatch_model` method.
+            peft_config (`Dict[str, Any]`, *optional*):
+                The configuration of the adapter to add, supported adapters are non-prefix tuning and adaption prompts
+                methods. This argument is used in case users directly pass PEFT state dicts
+            adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*):
+                The state dict of the adapter to load. This argument is used in case users directly pass PEFT state
+                dicts
         """
         check_peft_version(min_version=MIN_PEFT_VERSION)
@@ -122,33 +136,41 @@ class PeftAdapterMixin:
         from peft import PeftConfig, inject_adapter_in_model, load_peft_weights
         from peft.utils import set_peft_model_state_dict

-        if not self._hf_peft_config_loaded:
-            self._hf_peft_config_loaded = True
-        elif adapter_name in self.peft_config:
+        if self._hf_peft_config_loaded and adapter_name in self.peft_config:
             raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")

-        adapter_config_file = find_adapter_config_file(
-            peft_model_id,
-            revision=revision,
-            token=token,
-        )
-
-        if adapter_config_file is None:
+        if peft_model_id is None and (adapter_state_dict is None and peft_config is None):
             raise ValueError(
-                f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the "
-                "adapter model."
+                "You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
             )

-        loaded_peft_config = PeftConfig.from_pretrained(
-            peft_model_id,
-            revision=revision,
-            use_auth_token=token,
-        )
+        if peft_config is None:
+            adapter_config_file = find_adapter_config_file(
+                peft_model_id,
+                revision=revision,
+                token=token,
+            )
+
+            if adapter_config_file is None:
+                raise ValueError(
+                    f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the "
+                    "adapter model."
+                )
+
+            peft_config = PeftConfig.from_pretrained(
+                peft_model_id,
+                revision=revision,
+                use_auth_token=token,
+            )

         # Create and add fresh new adapters into the model.
-        inject_adapter_in_model(loaded_peft_config, self, adapter_name)
+        inject_adapter_in_model(peft_config, self, adapter_name)

-        adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token)
+        if not self._hf_peft_config_loaded:
+            self._hf_peft_config_loaded = True
+
+        if peft_model_id is not None:
+            adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token)

         # We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
         processed_adapter_state_dict = {}
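
The rewritten hunk above also makes the argument validation fail fast: with neither a `peft_model_id` nor a (`peft_config`, `adapter_state_dict`) pair, `load_adapter` now raises immediately instead of attempting a Hub lookup with a `None` id. A tiny sketch of that failure mode (placeholder model id, `peft` installed):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")  # placeholder id

# No loading source at all: the new check raises a ValueError up front.
try:
    model.load_adapter()
except ValueError as exc:
    print(exc)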

tests/peft_integration/test_peft_integration.py

@@ -16,6 +16,8 @@ import os
 import tempfile
 import unittest

+from huggingface_hub import hf_hub_download
+
 from transformers import AutoModelForCausalLM, OPTForCausalLM
 from transformers.testing_utils import require_peft, require_torch, require_torch_gpu, slow, torch_device
 from transformers.utils import is_torch_available
@@ -300,3 +302,33 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
         for model_id in self.peft_test_model_ids:
             pipe = pipeline("text-generation", model_id)
             _ = pipe("Hello")
+
+    def test_peft_add_adapter_with_state_dict(self):
+        """
+        Simple test of the basic usage of PEFT through `from_pretrained`. Checks that
+        `load_adapter` works as expected when a state dict is passed in directly.
+        """
+        from peft import LoraConfig
+
+        dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)
+
+        for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+                peft_config = LoraConfig(init_lora_weights=False)
+
+                with self.assertRaises(ValueError):
+                    model.load_adapter(peft_model_id=None)
+
+                state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
+                dummy_state_dict = torch.load(state_dict_path)
+
+                model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=peft_config)
+                with self.assertRaises(ValueError):
+                    model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=None)
+
+                self.assertTrue(self._check_lora_correctly_converted(model))
+
+                # dummy generation
+                _ = model.generate(input_ids=dummy_input)
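
Note that the second `load_adapter` call in the new test is expected to raise: the first call already registered the adapter under the default name, so the duplicate-name check added at the top of `load_adapter` fires before anything else. To run only this test locally (illustrative invocation, assuming a source checkout with `peft` installed):

python -m pytest tests/peft_integration/test_peft_integration.py -k test_peft_add_adapter_with_state_dict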