Enable multi-device for more models (#30379)

* feat: support for vitmatte

* feat: support for vivit

* feat: support for beit

* feat: support for blip :D

* feat: support for data2vec
This commit is contained in:
Jacky Lee 2024-04-22 02:57:27 -07:00 committed by GitHub
parent b20b017949
commit 8b02bb6e74
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 5 additions and 0 deletions

View File

@ -563,6 +563,7 @@ class BeitPreTrainedModel(PreTrainedModel):
base_model_prefix = "beit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["BeitLayer"]
def _init_weights(self, module):
"""Initialize the weights"""

View File

@ -549,6 +549,7 @@ class BlipTextPreTrainedModel(PreTrainedModel):
config_class = BlipTextConfig
base_model_prefix = "bert"
_no_split_modules = []
def _init_weights(self, module):
"""Initialize the weights"""

View File

@ -574,6 +574,7 @@ class Data2VecVisionPreTrainedModel(PreTrainedModel):
base_model_prefix = "data2vec_vision"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["Data2VecVisionLayer"]
def _init_weights(self, module):
"""Initialize the weights"""

View File

@ -73,6 +73,7 @@ class VitMattePreTrainedModel(PreTrainedModel):
config_class = VitMatteConfig
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = []
def _init_weights(self, module):
if isinstance(module, nn.Conv2d):

View File

@ -387,6 +387,7 @@ class VivitPreTrainedModel(PreTrainedModel):
base_model_prefix = "vivit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = []
def _init_weights(self, module):
"""Initialize the weights"""