507 lines
22 KiB
Python
507 lines
22 KiB
Python
# coding=utf-8
|
|
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import os
|
|
import tempfile
|
|
import unittest
|
|
|
|
from huggingface_hub import hf_hub_download
|
|
|
|
from transformers import AutoModelForCausalLM, OPTForCausalLM
|
|
from transformers.testing_utils import (
|
|
require_bitsandbytes,
|
|
require_peft,
|
|
require_torch,
|
|
require_torch_gpu,
|
|
slow,
|
|
torch_device,
|
|
)
|
|
from transformers.utils import is_torch_available
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
|
|
@require_peft
@require_torch
class PeftTesterMixin:
    """
    Shared fixtures for the PEFT integration tests.

    Test classes mixing this in loop over these tuples, so each test runs once per (checkpoint, model class) pair.
    """

    # Hub repositories holding a LoRA adapter (adapter config + adapter weights) rather than a full model.
    peft_test_model_ids = ("peft-internal-testing/tiny-OPTForCausalLM-lora",)

    # Plain transformers checkpoints; the tests attach adapters to them with `add_adapter` / `load_adapter`.
    transformers_test_model_ids = ("hf-internal-testing/tiny-random-OPTForCausalLM",)

    # Every checkpoint above is loaded through each of these classes.
    transformers_test_model_classes = (
        AutoModelForCausalLM,
        OPTForCausalLM,
    )
|
|
|
|
|
|
# TODO: run it with CI after PEFT release.
@slow
class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
    """
    A testing suite that makes sure that the PeftModel class is correctly integrated into the transformers library.
    """

    def _check_lora_correctly_converted(self, model):
        """
        Utility method to check if the model has correctly adapters injected on it.

        Returns `True` if at least one submodule of `model` is a PEFT `BaseTunerLayer`, `False` otherwise.
        """
        from peft.tuners.tuners_utils import BaseTunerLayer

        return any(isinstance(module, BaseTunerLayer) for module in model.modules())

    def _check_quantized_adapter_save(self, load_kwargs, expected_layer_class, expected_weights_file, **save_kwargs):
        """
        Load every PEFT test checkpoint quantized, save it, and check that only the adapter gets serialized.

        Args:
            load_kwargs (`dict`):
                Quantization kwargs forwarded to `from_pretrained`, e.g. `{"load_in_4bit": True}`.
            expected_layer_class (`str`):
                Class name the `v_proj` of the first decoder layer must have after quantization.
            expected_weights_file (`str`):
                Adapter weights file that `save_pretrained` is expected to write.
            **save_kwargs:
                Forwarded to `save_pretrained`, e.g. `safe_serialization=False`.
        """
        for model_id in self.peft_test_model_ids:
            for transformers_class in self.transformers_test_model_classes:
                peft_model = transformers_class.from_pretrained(model_id, device_map="auto", **load_kwargs)

                module = peft_model.model.decoder.layers[0].self_attn.v_proj
                self.assertEqual(module.__class__.__name__, expected_layer_class)
                self.assertIsNotNone(peft_model.hf_device_map)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    peft_model.save_pretrained(tmpdirname, **save_kwargs)
                    saved_files = os.listdir(tmpdirname)

                    self.assertIn(expected_weights_file, saved_files)
                    self.assertIn("adapter_config.json", saved_files)
                    # Only the adapter is written, never the (quantized) base model weights.
                    self.assertNotIn("pytorch_model.bin", saved_files)
                    self.assertNotIn("model.safetensors", saved_files)

    def test_peft_from_pretrained(self):
        """
        Simple test that tests the basic usage of PEFT model through `from_pretrained`.
        This checks if we pass a remote folder that contains an adapter config and adapter weights, it
        should correctly load a model that has adapters injected on it.
        """
        dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)

        for model_id in self.peft_test_model_ids:
            for transformers_class in self.transformers_test_model_classes:
                peft_model = transformers_class.from_pretrained(model_id).to(torch_device)

                self.assertTrue(self._check_lora_correctly_converted(peft_model))
                self.assertTrue(peft_model._hf_peft_config_loaded)
                # dummy generation
                _ = peft_model.generate(input_ids=dummy_input)

    def test_peft_state_dict(self):
        """
        Simple test that checks if the returned state dict of `get_adapter_state_dict()` method contains
        the expected keys.
        """
        for model_id in self.peft_test_model_ids:
            for transformers_class in self.transformers_test_model_classes:
                peft_model = transformers_class.from_pretrained(model_id).to(torch_device)

                state_dict = peft_model.get_adapter_state_dict()

                # An empty state dict would make the per-key check below pass vacuously.
                self.assertGreater(len(state_dict), 0)
                for key in state_dict:
                    self.assertIn("lora", key)

    def test_peft_save_pretrained(self):
        """
        Test that checks various combinations of `save_pretrained` with a model that has adapters loaded
        on it. This checks if the saved model contains the expected files (adapter weights and adapter config)
        and that each serialization format can be loaded back.
        """
        for model_id in self.peft_test_model_ids:
            for transformers_class in self.transformers_test_model_classes:
                peft_model = transformers_class.from_pretrained(model_id).to(torch_device)

                # Default (safetensors) serialization.
                with tempfile.TemporaryDirectory() as tmpdirname:
                    peft_model.save_pretrained(tmpdirname)
                    saved_files = os.listdir(tmpdirname)

                    self.assertIn("adapter_model.safetensors", saved_files)
                    self.assertIn("adapter_config.json", saved_files)

                    # Only the adapter must be serialized, not the base model.
                    self.assertNotIn("config.json", saved_files)
                    self.assertNotIn("pytorch_model.bin", saved_files)
                    self.assertNotIn("model.safetensors", saved_files)

                    peft_model = transformers_class.from_pretrained(tmpdirname).to(torch_device)
                    self.assertTrue(self._check_lora_correctly_converted(peft_model))

                # Pickle serialization. A fresh directory is used so the reload below cannot pick up the
                # safetensors file written above instead of `adapter_model.bin`.
                with tempfile.TemporaryDirectory() as tmpdirname:
                    peft_model.save_pretrained(tmpdirname, safe_serialization=False)
                    saved_files = os.listdir(tmpdirname)

                    self.assertIn("adapter_model.bin", saved_files)
                    self.assertIn("adapter_config.json", saved_files)
                    self.assertNotIn("adapter_model.safetensors", saved_files)

                    peft_model = transformers_class.from_pretrained(tmpdirname).to(torch_device)
                    self.assertTrue(self._check_lora_correctly_converted(peft_model))

    def test_peft_enable_disable_adapters(self):
        """
        A test that checks if `enable_adapters` and `disable_adapters` methods work as expected.
        """
        from peft import LoraConfig

        dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)

        for model_id in self.transformers_test_model_ids:
            for transformers_class in self.transformers_test_model_classes:
                peft_model = transformers_class.from_pretrained(model_id).to(torch_device)

                peft_config = LoraConfig(init_lora_weights=False)
                peft_model.add_adapter(peft_config)

                peft_logits = peft_model(dummy_input).logits

                peft_model.disable_adapters()
                peft_logits_disabled = peft_model(dummy_input).logits

                peft_model.enable_adapters()
                peft_logits_enabled = peft_model(dummy_input).logits

                # Re-enabling must restore the adapter output exactly. Disabling must change it
                # (presumably `init_lora_weights=False` makes the adapter a non-identity mapping).
                self.assertTrue(torch.allclose(peft_logits, peft_logits_enabled, atol=1e-12, rtol=1e-12))
                self.assertFalse(torch.allclose(peft_logits_enabled, peft_logits_disabled, atol=1e-12, rtol=1e-12))

    def test_peft_add_adapter(self):
        """
        Simple test that tests if `add_adapter` works as expected
        """
        from peft import LoraConfig

        dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)

        for model_id in self.transformers_test_model_ids:
            for transformers_class in self.transformers_test_model_classes:
                model = transformers_class.from_pretrained(model_id).to(torch_device)

                peft_config = LoraConfig(init_lora_weights=False)
                model.add_adapter(peft_config)

                self.assertTrue(self._check_lora_correctly_converted(model))
                # dummy generation
                _ = model.generate(input_ids=dummy_input)

    def test_peft_add_adapter_from_pretrained(self):
        """
        Test that an adapter attached with `add_adapter` survives a `save_pretrained` / `from_pretrained`
        round-trip.
        """
        from peft import LoraConfig

        for model_id in self.transformers_test_model_ids:
            for transformers_class in self.transformers_test_model_classes:
                model = transformers_class.from_pretrained(model_id).to(torch_device)

                peft_config = LoraConfig(init_lora_weights=False)
                model.add_adapter(peft_config)

                self.assertTrue(self._check_lora_correctly_converted(model))
                with tempfile.TemporaryDirectory() as tmpdirname:
                    model.save_pretrained(tmpdirname)
                    model_from_pretrained = transformers_class.from_pretrained(tmpdirname).to(torch_device)
                    self.assertTrue(self._check_lora_correctly_converted(model_from_pretrained))

    def test_peft_add_adapter_modules_to_save(self):
        """
        Simple test that tests if `add_adapter` works as expected when training with
        modules to save.
        """
        from peft import LoraConfig
        from peft.utils import ModulesToSaveWrapper

        dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)

        for model_id in self.transformers_test_model_ids:
            for transformers_class in self.transformers_test_model_classes:
                model = transformers_class.from_pretrained(model_id).to(torch_device)
                peft_config = LoraConfig(init_lora_weights=False, modules_to_save=["lm_head"])
                model.add_adapter(peft_config)
                self.assertTrue(self._check_lora_correctly_converted(model))

                # `lm_head` must be wrapped, and its `modules_to_save.default` weight must be trainable.
                has_modules_to_save_wrapper = False
                for name, module in model.named_modules():
                    if isinstance(module, ModulesToSaveWrapper):
                        has_modules_to_save_wrapper = True
                        self.assertTrue(module.modules_to_save.default.weight.requires_grad)
                        self.assertIn("lm_head", name)
                        break

                self.assertTrue(has_modules_to_save_wrapper)

                state_dict = model.get_adapter_state_dict()
                self.assertIn("lm_head.weight", state_dict)

                logits = model(dummy_input).logits
                loss = logits.mean()
                loss.backward()

                # Every trainable parameter must have received a gradient.
                for _, param in model.named_parameters():
                    if param.requires_grad:
                        self.assertIsNotNone(param.grad)

    def test_peft_add_adapter_training_gradient_checkpointing(self):
        """
        Simple test that tests if `add_adapter` works as expected when training with
        gradient checkpointing.
        """
        from peft import LoraConfig

        for model_id in self.transformers_test_model_ids:
            for transformers_class in self.transformers_test_model_classes:
                model = transformers_class.from_pretrained(model_id).to(torch_device)

                peft_config = LoraConfig(init_lora_weights=False)
                model.add_adapter(peft_config)

                self.assertTrue(self._check_lora_correctly_converted(model))

                # When attaching adapters the input embeddings stay frozen, so the embedding output
                # has requires_grad=False.
                dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)
                frozen_output = model.get_input_embeddings()(dummy_input)
                self.assertFalse(frozen_output.requires_grad)

                model.gradient_checkpointing_enable()

                # Enabling gradient checkpointing attaches a hook, so the embedding output should now
                # require grad.
                non_frozen_output = model.get_input_embeddings()(dummy_input)
                self.assertTrue(non_frozen_output.requires_grad)

                # To repro the Trainer issue
                dummy_input.requires_grad = False

                for name, param in model.named_parameters():
                    if "lora" in name.lower():
                        self.assertTrue(param.requires_grad)

                logits = model(dummy_input).logits
                loss = logits.mean()
                loss.backward()

                # Only LoRA parameters are trainable, and each of them must have received a gradient.
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        self.assertIn("lora", name.lower())
                        self.assertIsNotNone(param.grad)

    def test_peft_add_multi_adapter(self):
        """
        Simple test that tests the basic usage of PEFT model through `from_pretrained`. This test tests if
        add_adapter works as expected in multi-adapter setting.
        """
        from peft import LoraConfig

        dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)

        for model_id in self.transformers_test_model_ids:
            for transformers_class in self.transformers_test_model_classes:
                model = transformers_class.from_pretrained(model_id).to(torch_device)

                logits_original_model = model(dummy_input).logits

                peft_config = LoraConfig(init_lora_weights=False)

                model.add_adapter(peft_config)
                logits_adapter_1 = model(dummy_input).logits

                model.add_adapter(peft_config, adapter_name="adapter-2")
                logits_adapter_2 = model(dummy_input).logits

                self.assertTrue(self._check_lora_correctly_converted(model))

                # dummy generation
                _ = model.generate(input_ids=dummy_input)

                model.set_adapter("default")
                self.assertEqual(model.active_adapters(), ["default"])
                self.assertEqual(model.active_adapter(), "default")

                model.set_adapter("adapter-2")
                self.assertEqual(model.active_adapters(), ["adapter-2"])
                self.assertEqual(model.active_adapter(), "adapter-2")

                # Logits comparison: each adapter must change the base output, and the two adapters must differ.
                self.assertFalse(torch.allclose(logits_original_model, logits_adapter_1, atol=1e-6, rtol=1e-6))
                self.assertFalse(torch.allclose(logits_adapter_1, logits_adapter_2, atol=1e-6, rtol=1e-6))
                self.assertFalse(torch.allclose(logits_original_model, logits_adapter_2, atol=1e-6, rtol=1e-6))

                model.set_adapter(["adapter-2", "default"])
                self.assertEqual(model.active_adapters(), ["adapter-2", "default"])
                self.assertEqual(model.active_adapter(), "adapter-2")

                logits_adapter_mixed = model(dummy_input).logits
                self.assertFalse(torch.allclose(logits_adapter_1, logits_adapter_mixed, atol=1e-6, rtol=1e-6))
                self.assertFalse(torch.allclose(logits_adapter_2, logits_adapter_mixed, atol=1e-6, rtol=1e-6))

                # multi active adapter saving not supported
                with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname:
                    model.save_pretrained(tmpdirname)

    @require_torch_gpu
    @require_bitsandbytes
    def test_peft_from_pretrained_kwargs(self):
        """
        Simple test that tests the basic usage of PEFT model through `from_pretrained` + additional kwargs
        and see if the integration behaves as expected.
        """
        dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)

        for model_id in self.peft_test_model_ids:
            for transformers_class in self.transformers_test_model_classes:
                peft_model = transformers_class.from_pretrained(model_id, load_in_8bit=True, device_map="auto")

                module = peft_model.model.decoder.layers[0].self_attn.v_proj
                self.assertEqual(module.__class__.__name__, "Linear8bitLt")
                self.assertIsNotNone(peft_model.hf_device_map)

                # dummy generation
                _ = peft_model.generate(input_ids=dummy_input)

    @require_torch_gpu
    @require_bitsandbytes
    def test_peft_save_quantized(self):
        """
        Simple test that tests the basic usage of PEFT model save_pretrained with quantized base models
        """
        # 4-bit
        self._check_quantized_adapter_save({"load_in_4bit": True}, "Linear4bit", "adapter_model.safetensors")
        # 8-bit
        self._check_quantized_adapter_save({"load_in_8bit": True}, "Linear8bitLt", "adapter_model.safetensors")

    @require_torch_gpu
    @require_bitsandbytes
    def test_peft_save_quantized_regression(self):
        """
        Simple test that tests the basic usage of PEFT model save_pretrained with quantized base models
        Regression test to make sure everything works as expected before the safetensors integration.
        """
        # 4-bit
        self._check_quantized_adapter_save(
            {"load_in_4bit": True}, "Linear4bit", "adapter_model.bin", safe_serialization=False
        )
        # 8-bit
        self._check_quantized_adapter_save(
            {"load_in_8bit": True}, "Linear8bitLt", "adapter_model.bin", safe_serialization=False
        )

    def test_peft_pipeline(self):
        """
        Simple test that tests the basic usage of PEFT model + pipeline
        """
        from transformers import pipeline

        for model_id in self.peft_test_model_ids:
            pipe = pipeline("text-generation", model_id)
            _ = pipe("Hello")

    def test_peft_add_adapter_with_state_dict(self):
        """
        Test that `load_adapter` works when an adapter state dict and a `peft_config` are passed directly,
        and that it rejects calls that give it nothing to load from.
        """
        from peft import LoraConfig

        dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)

        for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
            for transformers_class in self.transformers_test_model_classes:
                model = transformers_class.from_pretrained(model_id).to(torch_device)

                peft_config = LoraConfig(init_lora_weights=False)

                # No hub id, no state dict and no config: nothing to load from.
                with self.assertRaises(ValueError):
                    model.load_adapter(peft_model_id=None)

                state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")

                # NOTE(review): `torch.load` unpickles the downloaded file. This is only acceptable because the
                # checkpoint comes from what is presumably a trusted internal testing repository.
                dummy_state_dict = torch.load(state_dict_path)

                model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=peft_config)
                # Loading a state dict without its `peft_config` must be rejected.
                with self.assertRaises(ValueError):
                    model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=None)
                self.assertTrue(self._check_lora_correctly_converted(model))

                # dummy generation
                _ = model.generate(input_ids=dummy_input)

    def test_peft_from_pretrained_hub_kwargs(self):
        """
        Tests different combinations of PEFT model + from_pretrained + hub kwargs
        """
        peft_model_id = "peft-internal-testing/tiny-opt-lora-revision"

        # Without `adapter_kwargs` the checkpoint cannot be resolved, so loading is expected to fail.
        with self.assertRaises(OSError):
            _ = AutoModelForCausalLM.from_pretrained(peft_model_id)

        # The adapter is reachable via the `test` revision, or via `main` under `test_subfolder`.
        for adapter_kwargs in ({"revision": "test"}, {"revision": "main", "subfolder": "test_subfolder"}):
            for transformers_class in (AutoModelForCausalLM, OPTForCausalLM):
                model = transformers_class.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
                self.assertTrue(self._check_lora_correctly_converted(model))