[`BC 4.37 -> 4.38`] for Llama family, memory and speed (#29753)

* attempt to fix

* the actual fix that works with compilation!

* this?

* temporary update

* nit?

* dispatcg to memory efficient?

* update both models that have static cache support

* fix copies fix compile

* make sure fix

* fix cohere and gemma

* fix beams?

* nit

* slipped through the cracks

* nit

* nits

* update

* fix-copies

* skip failing tests

* nits
This commit is contained in:
Arthur 2024-03-21 07:47:01 +09:00 committed by GitHub
parent 8dd4ce6f2c
commit ff841900e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 72 additions and 92 deletions

View File

@ -274,9 +274,7 @@ class CohereAttention(nn.Module):
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask
if cache_position is not None:
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
@ -559,8 +557,9 @@ class CohereSdpaAttention(CohereAttention):
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None and cache_position is not None:
causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
# if attention_mask is not None and cache_position is not None:
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
@ -692,7 +691,7 @@ class CoherePreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["CohereDecoderLayer"]
_skip_keys_device_placement = ["past_key_values", "causal_mask"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
@ -715,12 +714,6 @@ class CoherePreTrainedModel(PreTrainedModel):
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)
if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
causal_mask = torch.full(
(max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
for layer in self.model.layers:
device = layer.input_layernorm.weight.device
if hasattr(self.config, "_pre_quantization_dtype"):
@ -899,7 +892,7 @@ class CohereModel(CoherePreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
# embed positions
hidden_states = inputs_embeds
@ -967,25 +960,27 @@ class CohereModel(CoherePreTrainedModel):
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
batch_size, seq_length = input_tensor.shape[:2]
dtype = input_tensor.dtype
device = input_tensor.device
# support going beyond cached `max_position_embedding`
if seq_length > self.causal_mask.shape[-1]:
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
# We use the current dtype to avoid any overflows
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
sequence_length = input_tensor.shape[1]
if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache
target_length = self.config.max_position_embeddings
else: # dynamic cache
target_length = (
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
)
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
@ -995,8 +990,8 @@ class CohereModel(CoherePreTrainedModel):
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]:
offset = past_seen_tokens
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
offset = cache_position[0]
else:
offset = 0
mask_shape = attention_mask.shape

View File

@ -279,10 +279,7 @@ class GemmaAttention(nn.Module):
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
if cache_position is not None:
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
else:
causal_mask = attention_mask
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
@ -563,8 +560,8 @@ class GemmaSdpaAttention(GemmaAttention):
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None and cache_position is not None:
causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
@ -836,12 +833,6 @@ class GemmaModel(GemmaPreTrainedModel):
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
# NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
causal_mask = torch.full(
(config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
# Initialize weights and apply final processing
self.post_init()
@ -901,7 +892,7 @@ class GemmaModel(GemmaPreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
# embed positions
hidden_states = inputs_embeds
@ -975,26 +966,27 @@ class GemmaModel(GemmaPreTrainedModel):
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
batch_size, seq_length = input_tensor.shape[:2]
dtype = input_tensor.dtype
device = input_tensor.device
# support going beyond cached `max_position_embedding`
if seq_length > self.causal_mask.shape[-1]:
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
# We use the current dtype to avoid any overflows
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache
target_length = self.config.max_position_embeddings
else: # dynamic cache
target_length = (
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
)
causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
@ -1004,8 +996,8 @@ class GemmaModel(GemmaPreTrainedModel):
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]:
offset = past_seen_tokens
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
offset = cache_position[0]
else:
offset = 0
mask_shape = attention_mask.shape

View File

@ -371,9 +371,7 @@ class LlamaAttention(nn.Module):
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask
if cache_position is not None:
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
@ -658,8 +656,9 @@ class LlamaSdpaAttention(LlamaAttention):
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None and cache_position is not None:
causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
# if attention_mask is not None and cache_position is not None:
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
@ -792,7 +791,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values", "causal_mask"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
@ -815,12 +814,6 @@ class LlamaPreTrainedModel(PreTrainedModel):
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)
if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
causal_mask = torch.full(
(max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
for layer in self.model.layers:
device = layer.input_layernorm.weight.device
if hasattr(self.config, "_pre_quantization_dtype"):
@ -934,12 +927,6 @@ class LlamaModel(LlamaPreTrainedModel):
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
# NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
causal_mask = torch.full(
(config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
# Initialize weights and apply final processing
self.post_init()
@ -1000,7 +987,7 @@ class LlamaModel(LlamaPreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
# embed positions
hidden_states = inputs_embeds
@ -1068,25 +1055,27 @@ class LlamaModel(LlamaPreTrainedModel):
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
batch_size, seq_length = input_tensor.shape[:2]
dtype = input_tensor.dtype
device = input_tensor.device
# support going beyond cached `max_position_embedding`
if seq_length > self.causal_mask.shape[-1]:
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
# We use the current dtype to avoid any overflows
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
sequence_length = input_tensor.shape[1]
if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache
target_length = self.config.max_position_embeddings
else: # dynamic cache
target_length = (
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
)
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
@ -1096,8 +1085,8 @@ class LlamaModel(LlamaPreTrainedModel):
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]:
offset = past_seen_tokens
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
offset = cache_position[0]
else:
offset = 0
mask_shape = attention_mask.shape

View File

@ -283,7 +283,9 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
)
test_headmasking = False
test_pruning = False
fx_compatible = True
fx_compatible = (
False # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753
)
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer

View File

@ -300,7 +300,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
)
test_headmasking = False
test_pruning = False
fx_compatible = True
fx_compatible = (
False # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753
)
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer