Enable multi-device for more models (#30379)
* feat: support for vitmatte * feat: support for vivit * feat: support for beit * feat: support for blip :D * feat: support for data2vec
This commit is contained in:
parent
b20b017949
commit
8b02bb6e74
|
@ -563,6 +563,7 @@ class BeitPreTrainedModel(PreTrainedModel):
|
|||
base_model_prefix = "beit"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["BeitLayer"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
|
|
@ -549,6 +549,7 @@ class BlipTextPreTrainedModel(PreTrainedModel):
|
|||
|
||||
config_class = BlipTextConfig
|
||||
base_model_prefix = "bert"
|
||||
_no_split_modules = []
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
|
|
@ -574,6 +574,7 @@ class Data2VecVisionPreTrainedModel(PreTrainedModel):
|
|||
base_model_prefix = "data2vec_vision"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Data2VecVisionLayer"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
|
|
@ -73,6 +73,7 @@ class VitMattePreTrainedModel(PreTrainedModel):
|
|||
config_class = VitMatteConfig
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = []
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Conv2d):
|
||||
|
|
|
@ -387,6 +387,7 @@ class VivitPreTrainedModel(PreTrainedModel):
|
|||
base_model_prefix = "vivit"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = []
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
|
Loading…
Reference in New Issue