diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 65c9c2fdda..17a3f0a60a 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -185,7 +185,6 @@ else: if is_safetensors_available(): import safetensors.torch - if is_peft_available(): from peft import PeftModel @@ -213,7 +212,15 @@ if is_accelerate_available(): def _is_peft_model(model): - return is_peft_available() and isinstance(model, PeftModel) + if is_peft_available(): + classes_to_check = (PeftModel,) if is_peft_available() else () + # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321 + if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"): + from peft import PeftMixedModel + + classes_to_check = (*classes_to_check, PeftMixedModel) + return isinstance(model, classes_to_check) + return False if TYPE_CHECKING: @@ -700,7 +707,11 @@ class Trainer: # Inspect model forward signature to keep only the arguments it accepts. model_to_inspect = self.model if _is_peft_model(self.model): - model_to_inspect = self.model.get_base_model() + if hasattr(self.model, "get_base_model"): + model_to_inspect = self.model.get_base_model() + else: + # PeftMixedModel do not provide a `get_base_model` method + model_to_inspect = self.model.base_model.model signature = inspect.signature(model_to_inspect.forward) self._signature_columns = list(signature.parameters.keys()) # Labels may be named label or label_ids, the default data collator handles that.