fix olmo

parent 05bdef16b6
commit 7a0e0f6b68
@@ -653,6 +653,7 @@ class OlmoSdpaAttention(OlmoAttention):
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=causal_mask is None and q_len > 1,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
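Why the added `is_causal` argument matters: when `causal_mask` is `None`, SDPA is told to apply causal masking internally, which lets it dispatch to its fused kernels; the `q_len > 1` guard keeps it off during single-token decoding, where the one query row must attend to every cached key. A minimal sketch of the two cases, assuming plain PyTorch >= 2.0 (shapes are illustrative, not from the commit):

import torch
import torch.nn.functional as F

# Prefill: q_len > 1 and no explicit mask, so SDPA builds the causal mask itself.
q = torch.randn(1, 8, 5, 64)  # (batch, num_heads, q_len, head_dim)
k, v = torch.randn(1, 8, 5, 64), torch.randn(1, 8, 5, 64)
prefill_out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# Decode: q_len == 1 against a longer KV cache. is_causal must be False here,
# since torch aligns the causal mask top-left and would hide the cached keys.
q1 = torch.randn(1, 8, 1, 64)
k6, v6 = torch.randn(1, 8, 6, 64), torch.randn(1, 8, 6, 64)
decode_out = F.scaled_dot_product_attention(q1, k6, v6, attn_mask=None, is_causal=False)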
@@ -970,9 +971,7 @@ class OlmoModel(OlmoPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
-        )
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)

         # embed positions
         hidden_states = inputs_embeds
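The call site now hands `_update_causal_mask` the raw `past_seen_tokens` instead of precomputing the running length; the method derives the total itself (see the `target_length` hunk below). A small worked example of the bookkeeping, assuming one decode step with a dynamic cache:

past_seen_tokens = 12  # e.g. from past_key_values.get_seq_length(), 12 cached tokens
sequence_length = 1    # inputs_embeds.shape[1] for a single new token
# Old signature received current_length = past_seen_tokens + inputs_embeds.shape[1] = 13
# and used current_length + 1 = 14; the new one computes
# past_seen_tokens + sequence_length + 1 = 14 directly, so the mask width is unchanged.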
@@ -1036,17 +1035,32 @@ class OlmoModel(OlmoPreTrainedModel):
             attentions=all_self_attns,
         )

+    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+    def _update_causal_mask(
+        self,
+        attention_mask: torch.Tensor,
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_seen_tokens: int,
+    ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
         # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
         # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
-    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
+
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
             return None

+        if self.config._attn_implementation == "sdpa":
+            # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
+            # in order to dispatch on Flash Attention 2.
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+            ):
+                return None
+
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
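The added `sdpa` branch returns early (no 4D mask) whenever the helper decides the 2D mask carries no information, which is what lets `is_causal` in the attention hunk above take over. A rough restatement of the condition, covering only the simplest cases (the real `AttentionMaskConverter._ignore_causal_mask_sdpa` also handles tracing and other corner cases):

import torch

def ignore_causal_mask_sdpa_sketch(attention_mask, inputs_embeds, past_key_values_length):
    # The 4D causal mask is redundant when there is no padding to encode and
    # the query is either a single token or spans the whole key/value length.
    query_length = inputs_embeds.shape[1]
    kv_length = query_length + past_key_values_length
    no_padding = attention_mask is None or bool(torch.all(attention_mask == 1))
    return no_padding and (query_length == 1 or kv_length == query_length)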
@@ -1054,7 +1068,9 @@ class OlmoModel(OlmoPreTrainedModel):
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
-                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
             )

         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
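Sanity check on the dynamic-cache branch: with no mask tensor, the mask is built `(sequence_length, target_length)` wide, one column past the last cached position. A quick runnable check, reusing the numbers from the earlier example:

import torch

dtype, device = torch.float32, "cpu"
min_dtype = torch.finfo(dtype).min
past_seen_tokens, sequence_length = 12, 1
target_length = past_seen_tokens + sequence_length + 1  # 14, as in the hunk above

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
print(causal_mask.shape)  # torch.Size([1, 14])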