[`BC 4.37 -> 4.38`] for Llama family, memory and speed (#29753)
* attempt to fix * the actual fix that works with compilation! * this? * temporary update * nit? * dispatcg to memory efficient? * update both models that have static cache support * fix copies fix compile * make sure fix * fix cohere and gemma * fix beams? * nit * slipped through the cracks * nit * nits * update * fix-copies * skip failing tests * nits
This commit is contained in:
parent
8dd4ce6f2c
commit
ff841900e4
|
@ -274,9 +274,7 @@ class CohereAttention(nn.Module):
|
|||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask
|
||||
if cache_position is not None:
|
||||
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
|
@ -559,8 +557,9 @@ class CohereSdpaAttention(CohereAttention):
|
|||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None and cache_position is not None:
|
||||
causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
|
||||
# if attention_mask is not None and cache_position is not None:
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
|
@ -692,7 +691,7 @@ class CoherePreTrainedModel(PreTrainedModel):
|
|||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["CohereDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values", "causal_mask"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
|
@ -715,12 +714,6 @@ class CoherePreTrainedModel(PreTrainedModel):
|
|||
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
||||
)
|
||||
|
||||
if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
|
||||
causal_mask = torch.full(
|
||||
(max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
|
||||
)
|
||||
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
|
||||
|
||||
for layer in self.model.layers:
|
||||
device = layer.input_layernorm.weight.device
|
||||
if hasattr(self.config, "_pre_quantization_dtype"):
|
||||
|
@ -899,7 +892,7 @@ class CohereModel(CoherePreTrainedModel):
|
|||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens)
|
||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
@ -967,25 +960,27 @@ class CohereModel(CoherePreTrainedModel):
|
|||
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
||||
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
||||
def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
|
||||
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
|
||||
batch_size, seq_length = input_tensor.shape[:2]
|
||||
dtype = input_tensor.dtype
|
||||
device = input_tensor.device
|
||||
|
||||
# support going beyond cached `max_position_embedding`
|
||||
if seq_length > self.causal_mask.shape[-1]:
|
||||
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
|
||||
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
|
||||
|
||||
# We use the current dtype to avoid any overflows
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
|
||||
causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache
|
||||
target_length = self.config.max_position_embeddings
|
||||
else: # dynamic cache
|
||||
target_length = (
|
||||
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
|
||||
)
|
||||
|
||||
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.dim() == 2:
|
||||
|
@ -995,8 +990,8 @@ class CohereModel(CoherePreTrainedModel):
|
|||
elif attention_mask.dim() == 4:
|
||||
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
||||
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
||||
if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]:
|
||||
offset = past_seen_tokens
|
||||
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
|
||||
offset = cache_position[0]
|
||||
else:
|
||||
offset = 0
|
||||
mask_shape = attention_mask.shape
|
||||
|
|
|
@ -279,10 +279,7 @@ class GemmaAttention(nn.Module):
|
|||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
if cache_position is not None:
|
||||
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
|
||||
else:
|
||||
causal_mask = attention_mask
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
|
@ -563,8 +560,8 @@ class GemmaSdpaAttention(GemmaAttention):
|
|||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None and cache_position is not None:
|
||||
causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
|
@ -836,12 +833,6 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
|
||||
# NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
|
||||
causal_mask = torch.full(
|
||||
(config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
|
||||
)
|
||||
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
|
@ -901,7 +892,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens)
|
||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
@ -975,26 +966,27 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
||||
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
||||
def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
|
||||
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
|
||||
batch_size, seq_length = input_tensor.shape[:2]
|
||||
dtype = input_tensor.dtype
|
||||
device = input_tensor.device
|
||||
|
||||
# support going beyond cached `max_position_embedding`
|
||||
if seq_length > self.causal_mask.shape[-1]:
|
||||
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
|
||||
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
|
||||
|
||||
# We use the current dtype to avoid any overflows
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache
|
||||
target_length = self.config.max_position_embeddings
|
||||
else: # dynamic cache
|
||||
target_length = (
|
||||
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
|
||||
)
|
||||
|
||||
causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
|
||||
causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
|
||||
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.dim() == 2:
|
||||
|
@ -1004,8 +996,8 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||
elif attention_mask.dim() == 4:
|
||||
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
||||
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
||||
if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]:
|
||||
offset = past_seen_tokens
|
||||
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
|
||||
offset = cache_position[0]
|
||||
else:
|
||||
offset = 0
|
||||
mask_shape = attention_mask.shape
|
||||
|
|
|
@ -371,9 +371,7 @@ class LlamaAttention(nn.Module):
|
|||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask
|
||||
if cache_position is not None:
|
||||
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
|
@ -658,8 +656,9 @@ class LlamaSdpaAttention(LlamaAttention):
|
|||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None and cache_position is not None:
|
||||
causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
|
||||
# if attention_mask is not None and cache_position is not None:
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
|
@ -792,7 +791,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
|
|||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["LlamaDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values", "causal_mask"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
|
@ -815,12 +814,6 @@ class LlamaPreTrainedModel(PreTrainedModel):
|
|||
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
||||
)
|
||||
|
||||
if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
|
||||
causal_mask = torch.full(
|
||||
(max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
|
||||
)
|
||||
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
|
||||
|
||||
for layer in self.model.layers:
|
||||
device = layer.input_layernorm.weight.device
|
||||
if hasattr(self.config, "_pre_quantization_dtype"):
|
||||
|
@ -934,12 +927,6 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
|
||||
# NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
|
||||
causal_mask = torch.full(
|
||||
(config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
|
||||
)
|
||||
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
|
@ -1000,7 +987,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens)
|
||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
@ -1068,25 +1055,27 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
||||
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
||||
def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
|
||||
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
|
||||
batch_size, seq_length = input_tensor.shape[:2]
|
||||
dtype = input_tensor.dtype
|
||||
device = input_tensor.device
|
||||
|
||||
# support going beyond cached `max_position_embedding`
|
||||
if seq_length > self.causal_mask.shape[-1]:
|
||||
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
|
||||
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
|
||||
|
||||
# We use the current dtype to avoid any overflows
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
|
||||
causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache
|
||||
target_length = self.config.max_position_embeddings
|
||||
else: # dynamic cache
|
||||
target_length = (
|
||||
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
|
||||
)
|
||||
|
||||
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.dim() == 2:
|
||||
|
@ -1096,8 +1085,8 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||
elif attention_mask.dim() == 4:
|
||||
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
||||
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
||||
if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]:
|
||||
offset = past_seen_tokens
|
||||
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
|
||||
offset = cache_position[0]
|
||||
else:
|
||||
offset = 0
|
||||
mask_shape = attention_mask.shape
|
||||
|
|
|
@ -283,7 +283,9 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||
)
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
fx_compatible = True
|
||||
fx_compatible = (
|
||||
False # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753
|
||||
)
|
||||
|
||||
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
|
||||
# This is because we are hitting edge cases with the causal_mask buffer
|
||||
|
|
|
@ -300,7 +300,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||
)
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
fx_compatible = True
|
||||
fx_compatible = (
|
||||
False # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753
|
||||
)
|
||||
|
||||
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
|
||||
# This is because we are hitting edge cases with the causal_mask buffer
|
||||
|
|
Loading…
Reference in New Issue