* Revert "Re-enable SDPA's FA2 path (#30070)" This reverts commit05bdef16b6
. * Revert "Fix quality Olmo + SDPA (#30302)" This reverts commitec92f983af
.
parent 7509a0ad98
commit acab997bef
@@ -234,59 +234,6 @@ class AttentionMaskConverter:
 
         return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
 
-    @staticmethod
-    def _ignore_causal_mask_sdpa(
-        attention_mask: Optional[torch.Tensor],
-        inputs_embeds: torch.Tensor,
-        past_key_values_length: int,
-        sliding_window: Optional[int] = None,
-    ) -> bool:
-        """
-        Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
-
-        In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
-        `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
-        allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
-        """
-
-        batch_size, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
-        key_value_length = query_length + past_key_values_length
-
-        is_tracing = (
-            torch.jit.is_tracing()
-            or isinstance(inputs_embeds, torch.fx.Proxy)
-            or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
-        )
-
-        ignore_causal_mask = False
-
-        if attention_mask is None:
-            # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or
-            # `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
-            # Thus, we currently can NOT set `ignore_causal_mask = True` here. We would need a `torch._dynamo.is_exporting()` flag.
-            #
-            # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`).
-            if sliding_window is None or key_value_length < sliding_window:
-                ignore_causal_mask = not is_tracing
-        elif sliding_window is None or key_value_length < sliding_window:
-            if len(attention_mask.shape) == 4:
-                expected_shape = (batch_size, 1, query_length, key_value_length)
-                if tuple(attention_mask.shape) != expected_shape:
-                    raise ValueError(
-                        f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
-                    )
-            elif not is_tracing and torch.all(attention_mask == 1):
-                if query_length == 1 or key_value_length == query_length:
-                    # For query_length == 1, causal attention and bi-directional attention are the same.
-                    ignore_causal_mask = True
-
-                # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
-                # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
-                # Reference: https://github.com/pytorch/pytorch/issues/108108
-                # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
-
-        return ignore_causal_mask
-
 
 def _prepare_4d_causal_attention_mask(
     attention_mask: Optional[torch.Tensor],
@@ -358,6 +305,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
     attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
 
     key_value_length = input_shape[-1] + past_key_values_length
+    _, query_length = input_shape
 
     # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
     # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
@@ -368,12 +316,37 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
         or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
     )
 
-    ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
-        attention_mask=attention_mask,
-        inputs_embeds=inputs_embeds,
-        past_key_values_length=past_key_values_length,
-        sliding_window=sliding_window,
-    )
+    ignore_causal_mask = False
+
+    if attention_mask is None:
+        if sliding_window is None or key_value_length < sliding_window:
+            ignore_causal_mask = not is_tracing
+    elif sliding_window is None or key_value_length < sliding_window:
+        # 4d mask is passed through
+        if len(attention_mask.shape) == 4:
+            expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+            if tuple(attention_mask.shape) != expected_shape:
+                raise ValueError(
+                    f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+                )
+            else:
+                # if the 4D mask has correct shape - invert it and fill with negative infinity
+                inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
+                attention_mask = inverted_mask.masked_fill(
+                    inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+                )
+                return attention_mask
+
+        elif not is_tracing and torch.all(attention_mask == 1):
+            if query_length == 1:
+                # For query_length == 1, causal attention and bi-directional attention are the same.
+                ignore_causal_mask = True
+            elif key_value_length == query_length:
+                ignore_causal_mask = True
+
+            # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
+            # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
+            # Reference: https://github.com/pytorch/pytorch/issues/108108
 
     if ignore_causal_mask:
         expanded_4d_mask = None
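Note: both sides of this change implement the same dispatch rule; the revert only moves it from the shared `_ignore_causal_mask_sdpa` helper back inline. A minimal, self-contained sketch of that rule (the `can_drop_mask` name is illustrative; the tracing and sliding-window guards are left out):

import torch

def can_drop_mask(attention_mask, query_length: int, key_value_length: int) -> bool:
    # No user mask: only causality is needed, and SDPA can apply it by itself.
    if attention_mask is None:
        return True
    # A mask containing real padding carries information SDPA cannot rebuild.
    if not torch.all(attention_mask == 1):
        return False
    # All-ones mask: droppable for single-token decoding (query_length == 1)
    # or a mask-free prefill (key_value_length == query_length); other shapes
    # hit the alignment ambiguity of pytorch/pytorch#108108.
    return query_length == 1 or key_value_length == query_length

Dropping the mask matters because SDPA's flash backend rejects any explicit `attn_mask`; passing `None` is what makes the FA2 kernel eligible at all.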
@@ -590,15 +590,12 @@ class CohereSdpaAttention(CohereAttention):
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
 
-        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
-        # relying on the `is_causal` argument.
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
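Note on the deleted `is_causal=causal_mask is None and q_len > 1` line: SDPA's built-in causal mask is aligned to the top-left corner, so with a single query token it would attend only to key 0 instead of the whole cache, which is why the `q_len > 1` guard existed. A small illustration of that pitfall (toy tensors, not from this diff):

import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 1, 8)  # one decode-step query
k = torch.randn(1, 1, 5, 8)  # five cached key positions
v = torch.randn(1, 1, 5, 8)

full = F.scaled_dot_product_attention(q, k, v)                    # sees all 5 keys
causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # sees key 0 only
print(torch.allclose(full, causal))  # False (almost surely, with random inputs)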
@@ -911,7 +908,9 @@ class CohereModel(CoherePreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
+        )
 
         # embed positions
         hidden_states = inputs_embeds
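Note: the call-site change is purely the signature change. The restored `_update_causal_mask` expects the total length processed so far rather than only the cached length, hence the extra `+ inputs_embeds.shape[1]`. Illustrative numbers only:

past_seen_tokens = 16                      # tokens already in the KV cache
q_len = 4                                  # inputs_embeds.shape[1] for this step
current_length = past_seen_tokens + q_len  # the value passed after the revert
assert current_length == 20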
@@ -975,31 +974,16 @@ class CohereModel(CoherePreTrainedModel):
             attentions=all_self_attns,
         )
 
-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_seen_tokens: int,
-    ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
             return None
 
-        if self.config._attn_implementation == "sdpa":
-            # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
-            # in order to dispatch on Flash Attention 2.
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
-            ):
-                return None
-
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
@@ -1007,9 +991,7 @@ class CohereModel(CoherePreTrainedModel):
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
+                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
             )
 
         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
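Note: the `torch.full` line at the end of this hunk only starts the mask construction; the rest of `_update_causal_mask` (outside the hunk) un-masks the allowed positions. A condensed sketch of the overall idea, assuming `cache_position` holds each query's absolute position (the helper name is illustrative, not from the diff):

import torch

def build_causal_mask(sequence_length, target_length, cache_position, dtype):
    min_dtype = torch.finfo(dtype).min
    # start fully masked, then un-mask what each query may see
    mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
    # key position j is visible to query i exactly when j <= cache_position[i]
    visible = torch.arange(target_length) <= cache_position.reshape(-1, 1)
    return mask.masked_fill(visible, 0.0)

print(build_causal_mask(2, 6, torch.tensor([4, 5]), torch.float32))
# row 0 un-masks columns 0..4, row 1 un-masks columns 0..5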
@@ -570,15 +570,12 @@ class GemmaSdpaAttention(GemmaAttention):
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
 
-        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
-        # relying on the `is_causal` argument.
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -891,7 +888,9 @@ class GemmaModel(GemmaPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
+        )
 
         # embed positions
         hidden_states = inputs_embeds
@@ -961,31 +960,16 @@ class GemmaModel(GemmaPreTrainedModel):
             attentions=all_self_attns,
         )
 
-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_seen_tokens: int,
-    ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
             return None
 
-        if self.config._attn_implementation == "sdpa":
-            # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
-            # in order to dispatch on Flash Attention 2.
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
-            ):
-                return None
-
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
@@ -993,9 +977,7 @@ class GemmaModel(GemmaPreTrainedModel):
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
+                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
             )
 
         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
@@ -656,6 +656,7 @@ class LlamaSdpaAttention(LlamaAttention):
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         causal_mask = attention_mask
+        # if attention_mask is not None and cache_position is not None:
         if attention_mask is not None:
             causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
 
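Note: the restored commented-out line aside, the slice below it matters with a static KV cache, where the prepared 4D mask can be wider than the keys actually in use; trimming to `key_states.shape[-2]` keeps SDPA's shapes consistent. Shapes only, no model code:

import torch

causal_mask = torch.zeros(1, 1, 7, 128)  # mask sized for a static cache of 128
key_states = torch.randn(1, 8, 39, 64)   # only 39 key positions filled so far
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
print(causal_mask.shape)  # torch.Size([1, 1, 7, 39])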
@@ -666,15 +667,12 @@ class LlamaSdpaAttention(LlamaAttention):
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
 
-        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
-        # relying on the `is_causal` argument.
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -989,7 +987,9 @@ class LlamaModel(LlamaPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
+        )
 
         # embed positions
         hidden_states = inputs_embeds
@@ -1053,31 +1053,16 @@ class LlamaModel(LlamaPreTrainedModel):
             attentions=all_self_attns,
         )
 
-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_seen_tokens: int,
-    ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
             return None
 
-        if self.config._attn_implementation == "sdpa":
-            # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
-            # in order to dispatch on Flash Attention 2.
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
-            ):
-                return None
-
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
@@ -1085,9 +1070,7 @@ class LlamaModel(LlamaPreTrainedModel):
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
+                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
             )
 
         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
@@ -653,7 +653,6 @@ class OlmoSdpaAttention(OlmoAttention):
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -971,7 +970,9 @@ class OlmoModel(OlmoPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
+        )
 
         # embed positions
         hidden_states = inputs_embeds
@@ -1035,32 +1036,17 @@ class OlmoModel(OlmoPreTrainedModel):
             attentions=all_self_attns,
         )
 
-    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_seen_tokens: int,
-    ):
-
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
             return None
 
-        if self.config._attn_implementation == "sdpa":
-            # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
-            # in order to dispatch on Flash Attention 2.
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
-            ):
-                return None
-
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
@@ -1068,9 +1054,7 @@ class OlmoModel(OlmoPreTrainedModel):
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
+                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
            )
 
         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
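Note: the flash_attention_2 branch appears unchanged (as context) in all four models above: a 2D mask is forwarded only when it contains actual padding, and an all-ones or absent mask collapses to None so the kernel runs purely causal. Restated standalone (the helper name is illustrative):

import torch

def fa2_mask(attention_mask):
    if attention_mask is not None and 0.0 in attention_mask:
        return attention_mask  # real padding: the kernel needs it
    return None                # no padding: run purely causal

print(fa2_mask(torch.tensor([[1.0, 1.0, 1.0]])))  # None
print(fa2_mask(torch.tensor([[0.0, 1.0, 1.0]])))  # tensor([[0., 1., 1.]])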
@@ -3772,42 +3772,6 @@ class ModelTesterMixin:
 
         self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
 
-    @require_torch_sdpa
-    @require_torch_gpu
-    @slow
-    def test_sdpa_can_dispatch_on_flash(self):
-        compute_capability = torch.cuda.get_device_capability()
-        major, _ = compute_capability
-
-        if not torch.version.cuda or major < 8:
-            self.skipTest("This test requires an NVIDIA GPU with compute capability >= 8.0")
-
-        for model_class in self.all_model_classes:
-            if not model_class._supports_sdpa:
-                self.skipTest(f"{model_class.__name__} does not support SDPA")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            if config.model_type in ["llava", "llava_next", "vipllava"]:
-                self.skipTest("Llava-like models currently (transformers==4.39.1) requires an attention_mask input")
-            if config.model_type in ["idefics"]:
-                self.skipTest("Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa")
-                model.to(torch_device)
-
-                inputs_dict.pop("attention_mask", None)
-                inputs_dict.pop("decoder_attention_mask", None)
-
-                for name, inp in inputs_dict.items():
-                    if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
-                        inputs_dict[name] = inp.to(torch.float16)
-
-                with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
-                    _ = model(**inputs_dict)
-
     @require_torch_sdpa
     @slow
     def test_eager_matches_sdpa_generate(self):
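Note: the deleted `test_sdpa_can_dispatch_on_flash` forced a flash-only kernel context so that any code path still materializing a mask would fail loudly. The same check can be reproduced standalone; a minimal sketch, assuming an Ampere-or-newer CUDA GPU and fp16 inputs (SDPA raises if no enabled backend can serve the call):

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 32, 64, device="cuda", dtype=torch.float16)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    # An explicit attn_mask would error out here: flash rejects custom masks,
    # which is why the reverted code tried to pass attn_mask=None in the first place.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 32, 64])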