Pass attn_implementation when using AutoXXX.from_config (#30507)

* Pass attn_implementation when using AutoXXX.from_config

* Fix
This commit is contained in:
amyeroberts 2024-04-29 10:22:33 +01:00 committed by GitHub
parent 80126f98d8
commit e8acb70015
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 41 additions and 19 deletions

View File

@@ -1194,9 +1194,13 @@ class Blip2Model(Blip2PreTrainedModel):
         self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
         if config.use_decoder_only_language_model:
-            language_model = AutoModelForCausalLM.from_config(config.text_config)
+            language_model = AutoModelForCausalLM.from_config(
+                config.text_config, attn_implementation=config._attn_implementation
+            )
         else:
-            language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
+            language_model = AutoModelForSeq2SeqLM.from_config(
+                config.text_config, attn_implementation=config._attn_implementation
+            )
         # Update _tied_weights_keys using the base model used.
         if language_model._tied_weights_keys is not None:
@@ -1549,9 +1553,13 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
         self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
         if config.use_decoder_only_language_model:
-            language_model = AutoModelForCausalLM.from_config(config.text_config)
+            language_model = AutoModelForCausalLM.from_config(
+                config.text_config, attn_implementation=config._attn_implementation
+            )
         else:
-            language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
+            language_model = AutoModelForSeq2SeqLM.from_config(
+                config.text_config, attn_implementation=config._attn_implementation
+            )
         # Update _tied_weights_keys using the base model used.
         if language_model._tied_weights_keys is not None:

View File

@@ -367,7 +367,9 @@ class DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
-        self.backbone = AutoBackbone.from_config(config.backbone_config)
+        self.backbone = AutoBackbone.from_config(
+            config.backbone_config, attn_implementation=config._attn_implementation
+        )
         self.neck = DepthAnythingNeck(config)
         self.head = DepthAnythingDepthEstimationHead(config)

View File

@@ -209,12 +209,12 @@ class EncoderDecoderModel(PreTrainedModel):
         if encoder is None:
             from ..auto.modeling_auto import AutoModel
-            encoder = AutoModel.from_config(config.encoder)
+            encoder = AutoModel.from_config(config.encoder, attn_implementation=config._attn_implementation)
         if decoder is None:
             from ..auto.modeling_auto import AutoModelForCausalLM
-            decoder = AutoModelForCausalLM.from_config(config.decoder)
+            decoder = AutoModelForCausalLM.from_config(config.decoder, attn_implementation=config._attn_implementation)
         self.encoder = encoder
         self.decoder = decoder

View File

@@ -149,7 +149,9 @@ class FuyuForCausalLM(FuyuPreTrainedModel):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+        self.language_model = AutoModelForCausalLM.from_config(
+            config.text_config, attn_implementation=config._attn_implementation
+        )
         self.vision_embed_tokens = nn.Linear(
             config.patch_size * config.patch_size * config.num_channels, config.hidden_size

View File

@@ -1476,7 +1476,7 @@ class Idefics2Model(Idefics2PreTrainedModel):
         self.vision_model = Idefics2VisionTransformer(config.vision_config)
         self.connector = Idefics2Connector(config)
-        self.text_model = AutoModel.from_config(config.text_config)
+        self.text_model = AutoModel.from_config(config.text_config, attn_implementation=config._attn_implementation)
         self.image_seq_len = config.perceiver_config.resampler_n_latents
         self.image_token_id = self.config.image_token_id

View File

@@ -1251,9 +1251,13 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
         self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
         if config.use_decoder_only_language_model:
-            language_model = AutoModelForCausalLM.from_config(config.text_config)
+            language_model = AutoModelForCausalLM.from_config(
+                config.text_config, attn_implementation=config._attn_implementation
+            )
         else:
-            language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
+            language_model = AutoModelForSeq2SeqLM.from_config(
+                config.text_config, attn_implementation=config._attn_implementation
+            )
         if language_model._no_split_modules is not None:
             self._no_split_modules.extend(language_model._no_split_modules)

View File

@@ -506,12 +506,16 @@ class RagModel(RagPreTrainedModel):
         if question_encoder is None:
             from ..auto.modeling_auto import AutoModel
-            question_encoder = AutoModel.from_config(config.question_encoder)
+            question_encoder = AutoModel.from_config(
+                config.question_encoder, attn_implementation=config._attn_implementation
+            )
         if generator is None:
             from ..auto.modeling_auto import AutoModelForSeq2SeqLM
-            generator = AutoModelForSeq2SeqLM.from_config(config.generator)
+            generator = AutoModelForSeq2SeqLM.from_config(
+                config.generator, attn_implementation=config._attn_implementation
+            )
         self.retriever = retriever
         if self.retriever is not None:

View File

@@ -212,10 +212,10 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
         super().__init__(config)
         if encoder is None:
-            encoder = AutoModel.from_config(config.encoder)
+            encoder = AutoModel.from_config(config.encoder, attn_implementation=config._attn_implementation)
         if decoder is None:
-            decoder = AutoModelForCausalLM.from_config(config.decoder)
+            decoder = AutoModelForCausalLM.from_config(config.decoder, attn_implementation=config._attn_implementation)
         self.encoder = encoder
         self.decoder = decoder

View File

@@ -190,10 +190,10 @@ class VisionEncoderDecoderModel(PreTrainedModel):
         super().__init__(config)
         if encoder is None:
-            encoder = AutoModel.from_config(config.encoder)
+            encoder = AutoModel.from_config(config.encoder, attn_implementation=config._attn_implementation)
         if decoder is None:
-            decoder = AutoModelForCausalLM.from_config(config.decoder)
+            decoder = AutoModelForCausalLM.from_config(config.decoder, attn_implementation=config._attn_implementation)
         self.encoder = encoder
         self.decoder = decoder

View File

@@ -185,10 +185,12 @@ class VisionTextDualEncoderModel(PreTrainedModel):
         if isinstance(config.vision_config, CLIPVisionConfig):
             vision_model = CLIPVisionModel(config.vision_config)
         else:
-            vision_model = AutoModel.from_config(config.vision_config)
+            vision_model = AutoModel.from_config(
+                config.vision_config, attn_implementation=config._attn_implementation
+            )
         if text_model is None:
-            text_model = AutoModel.from_config(config.text_config)
+            text_model = AutoModel.from_config(config.text_config, attn_implementation=config._attn_implementation)
         self.vision_model = vision_model
         self.text_model = text_model