support PeftMixedModel signature inspect (#28321)
* support PeftMixedModel signature inspect * import PeftMixedModel only peft>=0.7.0 * Update src/transformers/trainer.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/trainer.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/trainer.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/trainer.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/trainer.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/trainer.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * fix styling * Update src/transformers/trainer.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/trainer.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * style fixup * fix note --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
8eb74c1c89
commit
bbe30c6968
|
@ -185,7 +185,6 @@ else:
|
|||
if is_safetensors_available():
|
||||
import safetensors.torch
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import PeftModel
|
||||
|
||||
|
@ -213,7 +212,15 @@ if is_accelerate_available():
|
|||
|
||||
|
||||
def _is_peft_model(model):
|
||||
return is_peft_available() and isinstance(model, PeftModel)
|
||||
if is_peft_available():
|
||||
classes_to_check = (PeftModel,) if is_peft_available() else ()
|
||||
# Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321
|
||||
if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
|
||||
from peft import PeftMixedModel
|
||||
|
||||
classes_to_check = (*classes_to_check, PeftMixedModel)
|
||||
return isinstance(model, classes_to_check)
|
||||
return False
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -700,7 +707,11 @@ class Trainer:
|
|||
# Inspect model forward signature to keep only the arguments it accepts.
|
||||
model_to_inspect = self.model
|
||||
if _is_peft_model(self.model):
|
||||
model_to_inspect = self.model.get_base_model()
|
||||
if hasattr(self.model, "get_base_model"):
|
||||
model_to_inspect = self.model.get_base_model()
|
||||
else:
|
||||
# PeftMixedModel do not provide a `get_base_model` method
|
||||
model_to_inspect = self.model.base_model.model
|
||||
signature = inspect.signature(model_to_inspect.forward)
|
||||
self._signature_columns = list(signature.parameters.keys())
|
||||
# Labels may be named label or label_ids, the default data collator handles that.
|
||||
|
|
Loading…
Reference in New Issue