[T5] Adding `model_parallel = False` to `T5ForTokenClassification` and `MT5ForTokenClassification` (#30763)

* Adding model_parallel = False

* Revert "Adding model_parallel = False"

This reverts commit ba1d99976a.

* Trainer: circumvent error for model  in which is_parallelizable is True but does not have model_parallel attribute
This commit is contained in:
Masahiro Suzuki 2024-05-14 22:39:25 +09:00 committed by GitHub
parent 9ef3884046
commit d84f34ad77
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 1 deletions

View File

@ -436,7 +436,7 @@ class Trainer:
"https://huggingface.co/docs/transformers/model_doc/auto"
)
if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
if getattr(model, "is_parallelizable", False) and getattr(model, "model_parallel", False):
self.is_model_parallel = True
else:
self.is_model_parallel = False