[PEFT] Allow PEFT model dict to be loaded (#25721)
* Allow PEFT model dict to be loaded * make style * make style * Apply suggestions from code review * address comments * fixup * final change * added tests * fix test * better logic for handling if adapter has been loaded * Update tests/peft_integration/test_peft_integration.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: younesbelkada <younesbelkada@gmail.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
8b13471494
commit
0a55d9f737
|
@ -12,13 +12,14 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Optional
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||||
|
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
check_peft_version,
|
check_peft_version,
|
||||||
find_adapter_config_file,
|
find_adapter_config_file,
|
||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
is_peft_available,
|
is_peft_available,
|
||||||
|
is_torch_available,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -30,6 +31,11 @@ if is_accelerate_available():
|
||||||
# Minimum PEFT version supported for the integration
|
# Minimum PEFT version supported for the integration
|
||||||
MIN_PEFT_VERSION = "0.5.0"
|
MIN_PEFT_VERSION = "0.5.0"
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -61,7 +67,7 @@ class PeftAdapterMixin:
|
||||||
|
|
||||||
def load_adapter(
|
def load_adapter(
|
||||||
self,
|
self,
|
||||||
peft_model_id: str,
|
peft_model_id: Optional[str] = None,
|
||||||
adapter_name: Optional[str] = None,
|
adapter_name: Optional[str] = None,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
|
@ -69,6 +75,8 @@ class PeftAdapterMixin:
|
||||||
max_memory: Optional[str] = None,
|
max_memory: Optional[str] = None,
|
||||||
offload_folder: Optional[str] = None,
|
offload_folder: Optional[str] = None,
|
||||||
offload_index: Optional[int] = None,
|
offload_index: Optional[int] = None,
|
||||||
|
peft_config: Dict[str, Any] = None,
|
||||||
|
adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we
|
Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we
|
||||||
|
@ -77,7 +85,7 @@ class PeftAdapterMixin:
|
||||||
Requires peft as a backend to load the adapter weights.
|
Requires peft as a backend to load the adapter weights.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
peft_model_id (`str`):
|
peft_model_id (`str`, *optional*):
|
||||||
The identifier of the model to look for on the Hub, or a local path to the saved adapter config file
|
The identifier of the model to look for on the Hub, or a local path to the saved adapter config file
|
||||||
and adapter weights.
|
and adapter weights.
|
||||||
adapter_name (`str`, *optional*):
|
adapter_name (`str`, *optional*):
|
||||||
|
@ -114,6 +122,12 @@ class PeftAdapterMixin:
|
||||||
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
|
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
|
||||||
offload_index (`int`, *optional*):
|
offload_index (`int`, *optional*):
|
||||||
`offload_index` argument to be passed to `accelerate.dispatch_model` method.
|
`offload_index` argument to be passed to `accelerate.dispatch_model` method.
|
||||||
|
peft_config (`Dict[str, Any]`, *optional*):
|
||||||
|
The configuration of the adapter to add. Supported adapters are non-prefix tuning and adaption prompts
|
||||||
|
methods. This argument is used in case users directly pass PEFT state dicts.
|
||||||
|
adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*):
|
||||||
|
The state dict of the adapter to load. This argument is used in case users directly pass PEFT state
|
||||||
|
dicts.
|
||||||
"""
|
"""
|
||||||
check_peft_version(min_version=MIN_PEFT_VERSION)
|
check_peft_version(min_version=MIN_PEFT_VERSION)
|
||||||
|
|
||||||
|
@ -122,33 +136,41 @@ class PeftAdapterMixin:
|
||||||
from peft import PeftConfig, inject_adapter_in_model, load_peft_weights
|
from peft import PeftConfig, inject_adapter_in_model, load_peft_weights
|
||||||
from peft.utils import set_peft_model_state_dict
|
from peft.utils import set_peft_model_state_dict
|
||||||
|
|
||||||
if not self._hf_peft_config_loaded:
|
if self._hf_peft_config_loaded and adapter_name in self.peft_config:
|
||||||
self._hf_peft_config_loaded = True
|
|
||||||
elif adapter_name in self.peft_config:
|
|
||||||
raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
|
raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
|
||||||
|
|
||||||
adapter_config_file = find_adapter_config_file(
|
if peft_model_id is None and (adapter_state_dict is None and peft_config is None):
|
||||||
peft_model_id,
|
|
||||||
revision=revision,
|
|
||||||
token=token,
|
|
||||||
)
|
|
||||||
|
|
||||||
if adapter_config_file is None:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the "
|
"You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
|
||||||
"adapter model."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
loaded_peft_config = PeftConfig.from_pretrained(
|
if peft_config is None:
|
||||||
peft_model_id,
|
adapter_config_file = find_adapter_config_file(
|
||||||
revision=revision,
|
peft_model_id,
|
||||||
use_auth_token=token,
|
revision=revision,
|
||||||
)
|
token=token,
|
||||||
|
)
|
||||||
|
|
||||||
|
if adapter_config_file is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the "
|
||||||
|
"adapter model."
|
||||||
|
)
|
||||||
|
|
||||||
|
peft_config = PeftConfig.from_pretrained(
|
||||||
|
peft_model_id,
|
||||||
|
revision=revision,
|
||||||
|
use_auth_token=token,
|
||||||
|
)
|
||||||
|
|
||||||
# Create and add fresh new adapters into the model.
|
# Create and add fresh new adapters into the model.
|
||||||
inject_adapter_in_model(loaded_peft_config, self, adapter_name)
|
inject_adapter_in_model(peft_config, self, adapter_name)
|
||||||
|
|
||||||
adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token)
|
if not self._hf_peft_config_loaded:
|
||||||
|
self._hf_peft_config_loaded = True
|
||||||
|
|
||||||
|
if peft_model_id is not None:
|
||||||
|
adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token)
|
||||||
|
|
||||||
# We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
|
# We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
|
||||||
processed_adapter_state_dict = {}
|
processed_adapter_state_dict = {}
|
||||||
|
|
|
@ -16,6 +16,8 @@ import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM, OPTForCausalLM
|
from transformers import AutoModelForCausalLM, OPTForCausalLM
|
||||||
from transformers.testing_utils import require_peft, require_torch, require_torch_gpu, slow, torch_device
|
from transformers.testing_utils import require_peft, require_torch, require_torch_gpu, slow, torch_device
|
||||||
from transformers.utils import is_torch_available
|
from transformers.utils import is_torch_available
|
||||||
|
@ -300,3 +302,33 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
||||||
for model_id in self.peft_test_model_ids:
|
for model_id in self.peft_test_model_ids:
|
||||||
pipe = pipeline("text-generation", model_id)
|
pipe = pipeline("text-generation", model_id)
|
||||||
_ = pipe("Hello")
|
_ = pipe("Hello")
|
||||||
|
|
||||||
|
def test_peft_add_adapter_with_state_dict(self):
|
||||||
|
"""
|
||||||
|
Simple test that checks the basic usage of a PEFT model through `from_pretrained`. This test checks if
|
||||||
|
`load_adapter` works as expected when a state dict is passed.
|
||||||
|
"""
|
||||||
|
from peft import LoraConfig
|
||||||
|
|
||||||
|
dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)
|
||||||
|
|
||||||
|
for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
|
||||||
|
for transformers_class in self.transformers_test_model_classes:
|
||||||
|
model = transformers_class.from_pretrained(model_id).to(torch_device)
|
||||||
|
|
||||||
|
peft_config = LoraConfig(init_lora_weights=False)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
model.load_adapter(peft_model_id=None)
|
||||||
|
|
||||||
|
state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
|
||||||
|
|
||||||
|
dummy_state_dict = torch.load(state_dict_path)
|
||||||
|
|
||||||
|
model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=peft_config)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
model.load_adapter(model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=None))
|
||||||
|
self.assertTrue(self._check_lora_correctly_converted(model))
|
||||||
|
|
||||||
|
# dummy generation
|
||||||
|
_ = model.generate(input_ids=dummy_input)
|
||||||
|
|
Loading…
Reference in New Issue