support PeftMixedModel signature inspect (#28321)

* support PeftMixedModel signature inspect

* import PeftMixedModel only peft>=0.7.0

* Update src/transformers/trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* fix styling

* Update src/transformers/trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/trainer.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* style fixup

* fix note

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Facico 2024-01-26 19:05:01 +08:00 committed by GitHub
parent 8eb74c1c89
commit bbe30c6968
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 14 additions and 3 deletions

View File

@ -185,7 +185,6 @@ else:
if is_safetensors_available():
import safetensors.torch
if is_peft_available():
from peft import PeftModel
@ -213,7 +212,15 @@ if is_accelerate_available():
def _is_peft_model(model):
return is_peft_available() and isinstance(model, PeftModel)
if is_peft_available():
classes_to_check = (PeftModel,) if is_peft_available() else ()
# Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321
if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
from peft import PeftMixedModel
classes_to_check = (*classes_to_check, PeftMixedModel)
return isinstance(model, classes_to_check)
return False
if TYPE_CHECKING:
@ -700,7 +707,11 @@ class Trainer:
# Inspect model forward signature to keep only the arguments it accepts.
model_to_inspect = self.model
if _is_peft_model(self.model):
model_to_inspect = self.model.get_base_model()
if hasattr(self.model, "get_base_model"):
model_to_inspect = self.model.get_base_model()
else:
# PeftMixedModel do not provide a `get_base_model` method
model_to_inspect = self.model.base_model.model
signature = inspect.signature(model_to_inspect.forward)
self._signature_columns = list(signature.parameters.keys())
# Labels may be named label or label_ids, the default data collator handles that.