Cache: add new flag to distinguish models that support `Cache` but not static cache (#30800)

* jamba cache

* new flag

* generate exception
Joao Gante 2024-05-16 12:08:35 +01:00 committed by GitHub
parent 17cc71e149
commit 9d889f870e
19 changed files with 23 additions and 3 deletions
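
For context (not part of the diff): `_supports_cache_class` marks models that accept a `Cache` instance as `past_key_values`, while the new `_supports_static_cache` additionally marks models known to work with `StaticCache`, i.e. with `generate(..., cache_implementation="static")`. A minimal sketch of how calling code might branch on the two flags; `pick_cache_implementation` is a hypothetical helper, not part of transformers:

```python
from typing import Optional

from transformers import PreTrainedModel


def pick_cache_implementation(model: PreTrainedModel) -> Optional[str]:
    """Return a `cache_implementation` value the model should accept in `generate`."""
    if model._supports_static_cache:
        return "static"  # compile-friendly StaticCache is safe to request
    if model._supports_cache_class:
        return None      # fall back to the model's default `Cache` (e.g. DynamicCache)
    return None          # legacy tuple cache only
```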

@@ -1616,6 +1616,11 @@ class GenerationMixin:
"issue: https://github.com/huggingface/transformers/issues/28981."
)
if generation_config.cache_implementation == "static":
if not self._supports_static_cache:
raise ValueError(
"This model does not support `cache_implementation='static'`. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981"
)
model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length)
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
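
A hedged usage sketch of the guard added above: requesting the static cache on a model that only sets `_supports_cache_class` now fails early with this `ValueError` instead of breaking deeper in the forward pass. The checkpoint name below is a placeholder:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; any causal LM checkpoint follows the same path.
model = AutoModelForCausalLM.from_pretrained("some/causal-lm-checkpoint")
tokenizer = AutoTokenizer.from_pretrained("some/causal-lm-checkpoint")
inputs = tokenizer("Hello", return_tensors="pt")

try:
    # Builds a StaticCache when `_supports_static_cache` is True;
    # otherwise raises the ValueError introduced in this hunk.
    model.generate(**inputs, max_new_tokens=8, cache_implementation="static")
except ValueError as err:
    print(err)  # message points to issue #28981
```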

@@ -1280,8 +1280,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# SDPA support
_supports_sdpa = False
# Has support for a `Cache` instance as `past_key_values`
# Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`?
_supports_cache_class = False
_supports_static_cache = False
@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
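
A hedged sketch of how a model class opts in to the flags; `MyModelPreTrainedModel` is a hypothetical example, not a class touched by this commit:

```python
from transformers import PretrainedConfig, PreTrainedModel


class MyModelPreTrainedModel(PreTrainedModel):
    config_class = PretrainedConfig
    base_model_prefix = "model"

    # Accepts a `Cache` instance as `past_key_values`
    _supports_cache_class = True
    # Additionally known to work with `StaticCache` / `cache_implementation="static"`
    _supports_static_cache = True
```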

@@ -720,6 +720,7 @@ class CoherePreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range

@@ -938,6 +938,7 @@ class DbrxPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
def _init_weights(self, module: nn.Module):
std = self.config.initializer_range

@@ -703,6 +703,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range

@@ -1341,6 +1341,7 @@ class Idefics2PreTrainedModel(PreTrainedModel):
_no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
def _init_weights(self, module):
# important: this ported version of Idefics2 isn't meant for training from scratch - only

@@ -1261,6 +1261,7 @@ class JambaPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache
def _init_weights(self, module):
std = self.config.initializer_range
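
A hedged illustration of the Jamba case, assuming a transformers build that includes these models: the class advertises `Cache` support (its own `HybridMambaAttentionDynamicCache`) but leaves the new static flag at its `False` default, which can be checked without downloading any weights:

```python
# Class-level flags can be inspected without instantiating a model.
from transformers import JambaForCausalLM, LlamaForCausalLM

print(JambaForCausalLM._supports_cache_class)   # True (hybrid Mamba/attention cache only)
print(JambaForCausalLM._supports_static_cache)  # False -> static cache is rejected in generate
print(LlamaForCausalLM._supports_static_cache)  # True  -> static cache is allowed
```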

@@ -799,6 +799,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range

@@ -810,6 +810,7 @@ class MistralPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module):
std = self.config.initializer_range

@@ -989,6 +989,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module):
std = self.config.initializer_range

@@ -776,6 +776,7 @@ class OlmoPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range

@@ -457,6 +457,7 @@ class PersimmonPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["PersimmonDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
def _init_weights(self, module):
std = self.config.initializer_range

@@ -825,6 +825,7 @@ class PhiPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module):
std = self.config.initializer_range

@@ -921,6 +921,7 @@ class Phi3PreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = False
_supports_cache_class = True
_version = "0.0.5"

@@ -821,6 +821,7 @@ class Qwen2PreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module):
std = self.config.initializer_range

@@ -975,6 +975,7 @@ class Qwen2MoePreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module):
std = self.config.initializer_range

@@ -541,7 +541,6 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = ["cache"]
_supports_flash_attn_2 = False
_supports_sdpa = False # we can't compare with eager for now
_supports_cache_class = True
def _init_weights(self, module):
std = math.sqrt(self.config.w_init_variance_scale / self.config.conv1d_width)

@@ -799,6 +799,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module):
std = self.config.initializer_range

@@ -4365,7 +4365,7 @@ class ModelTesterMixin:
self.skipTest("Model architecture has no generative classes, and thus not necessarily supporting 4D masks")
for model_class in self.all_generative_model_classes:
if not model_class._supports_cache_class:
if not model_class._supports_static_cache:
self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config).to(device=torch_device, dtype=torch.float32)
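
A hedged sketch of the same gating pattern for downstream test suites, mirroring the updated condition above (the model class and test body are placeholders):

```python
import unittest

from transformers import MistralForCausalLM


class Custom4DMaskTest(unittest.TestCase):
    model_class = MistralForCausalLM  # placeholder

    def test_custom_4d_attention_mask(self):
        # Skip unless the model is explicitly flagged as StaticCache-compatible.
        if not self.model_class._supports_static_cache:
            self.skipTest(
                f"{self.model_class.__name__} is not guaranteed to work with custom 4D attention masks"
            )
        # ... the actual 4D-mask assertions would go here ...
```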