Llama: fix custom 4D masks, v2 (#30348)

* 4d mask fixes

* Update custom 4D mask logic

* test moved to mixin

* extra tests 4d mask

* upd 4d mask and StaticCache handling

* added Mask4DTestHard to mistral tests

* post-rebase fixes

* test fixes for StaticCache

* make fix-copies

* upd 1 after #30476

* fix common tests

* rm elif attention_mask.dim() == 4:

* tests combined, fixed, mixtral supported

* bigbird style chg reverted

* rm if attention_mask.dim() == 2

* modeling_llama formatting chg

---------

Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
Poedator 2024-05-13 13:46:06 +02:00 committed by GitHub
parent 453893ed15
commit a0779b9e19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 541 additions and 366 deletions

View File

@ -250,7 +250,7 @@ class AttentionMaskConverter:
allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
""" """
batch_size, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1] _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
key_value_length = query_length + past_key_values_length key_value_length = query_length + past_key_values_length
is_tracing = ( is_tracing = (
@ -275,11 +275,7 @@ class AttentionMaskConverter:
ignore_causal_mask = True ignore_causal_mask = True
elif sliding_window is None or key_value_length < sliding_window: elif sliding_window is None or key_value_length < sliding_window:
if len(attention_mask.shape) == 4: if len(attention_mask.shape) == 4:
expected_shape = (batch_size, 1, query_length, key_value_length) return False
if tuple(attention_mask.shape) != expected_shape:
raise ValueError(
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
)
elif (is_training or not is_tracing) and torch.all(attention_mask == 1): elif (is_training or not is_tracing) and torch.all(attention_mask == 1):
if query_length == 1 or key_value_length == query_length: if query_length == 1 or key_value_length == query_length:
# For query_length == 1, causal attention and bi-directional attention are the same. # For query_length == 1, causal attention and bi-directional attention are the same.
@ -387,12 +383,18 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
) )
else: else:
expanded_4d_mask = attn_mask_converter.to_4d( if attention_mask.dim() == 4:
attention_mask, # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
input_shape[-1], if attention_mask.max() != 0:
dtype=inputs_embeds.dtype, raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
key_value_length=key_value_length, expanded_4d_mask = attention_mask
) else:
expanded_4d_mask = attn_mask_converter.to_4d(
attention_mask,
input_shape[-1],
dtype=inputs_embeds.dtype,
key_value_length=key_value_length,
)
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.

View File

@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" PyTorch BigBirdPegasus model.""" """ PyTorch BigBirdPegasus model."""
import copy import copy
import math import math
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union

View File

@ -995,37 +995,27 @@ class CohereModel(CoherePreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) if attention_mask is not None and attention_mask.dim() == 4:
if sequence_length != 1: # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
causal_mask = torch.triu(causal_mask, diagonal=1) if attention_mask.max() != 0:
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) causal_mask = attention_mask
if attention_mask is not None: else:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit causal_mask = torch.full(
if attention_mask.dim() == 2: (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1] mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0 padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype padding_mask, min_dtype
) )
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
logger.warning_once(
"Passing a 4d mask shorter than the input length is deprecated and will be removed in "
"transformers v4.42.0"
)
offset = cache_position[0]
else:
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None

View File

@ -1241,37 +1241,27 @@ class DbrxModel(DbrxPreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) if attention_mask is not None and attention_mask.dim() == 4:
if sequence_length != 1: # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
causal_mask = torch.triu(causal_mask, diagonal=1) if attention_mask.max() != 0:
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) causal_mask = attention_mask
if attention_mask is not None: else:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit causal_mask = torch.full(
if attention_mask.dim() == 2: (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1] mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0 padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype padding_mask, min_dtype
) )
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
logger.warning_once(
"Passing a 4d mask shorter than the input length is deprecated and will be removed in "
"transformers v4.42.0"
)
offset = cache_position[0]
else:
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None

View File

@ -986,37 +986,27 @@ class GemmaModel(GemmaPreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) if attention_mask is not None and attention_mask.dim() == 4:
if sequence_length != 1: # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
causal_mask = torch.triu(causal_mask, diagonal=1) if attention_mask.max() != 0:
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) causal_mask = attention_mask
if attention_mask is not None: else:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit causal_mask = torch.full(
if attention_mask.dim() == 2: (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1] mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0 padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype padding_mask, min_dtype
) )
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
logger.warning_once(
"Passing a 4d mask shorter than the input length is deprecated and will be removed in "
"transformers v4.42.0"
)
offset = cache_position[0]
else:
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None

View File

@ -1073,37 +1073,27 @@ class LlamaModel(LlamaPreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) if attention_mask is not None and attention_mask.dim() == 4:
if sequence_length != 1: # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
causal_mask = torch.triu(causal_mask, diagonal=1) if attention_mask.max() != 0:
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) causal_mask = attention_mask
if attention_mask is not None: else:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit causal_mask = torch.full(
if attention_mask.dim() == 2: (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1] mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0 padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype padding_mask, min_dtype
) )
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
logger.warning_once(
"Passing a 4d mask shorter than the input length is deprecated and will be removed in "
"transformers v4.42.0"
)
offset = cache_position[0]
else:
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None

View File

@ -1052,37 +1052,27 @@ class OlmoModel(OlmoPreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) if attention_mask is not None and attention_mask.dim() == 4:
if sequence_length != 1: # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
causal_mask = torch.triu(causal_mask, diagonal=1) if attention_mask.max() != 0:
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) causal_mask = attention_mask
if attention_mask is not None: else:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit causal_mask = torch.full(
if attention_mask.dim() == 2: (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1] mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0 padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype padding_mask, min_dtype
) )
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
logger.warning_once(
"Passing a 4d mask shorter than the input length is deprecated and will be removed in "
"transformers v4.42.0"
)
offset = cache_position[0]
else:
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None

View File

@ -12,8 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" Testing suite for the PyTorch LLaMA model. """ """Testing suite for the PyTorch LLaMA model."""
import gc
import tempfile import tempfile
import unittest import unittest
@ -21,7 +22,7 @@ import pytest
from packaging import version from packaging import version
from parameterized import parameterized from parameterized import parameterized
from transformers import LlamaConfig, is_torch_available, set_seed from transformers import LlamaConfig, StaticCache, is_torch_available, set_seed
from transformers.testing_utils import ( from transformers.testing_utils import (
require_bitsandbytes, require_bitsandbytes,
require_flash_attn, require_flash_attn,
@ -804,7 +805,7 @@ end
'<PRE> <SUF>\ndef main():\n factory = InterfaceManagerFactory(start=datetime.now())\n managers = []\n for i in range(10):\n managers.append(factory.build(id=i))\n <MID> class InterfaceManagerFactory(AbstractManagerFactory):\n def __init__(', '<PRE> <SUF>\ndef main():\n factory = InterfaceManagerFactory(start=datetime.now())\n managers = []\n for i in range(10):\n managers.append(factory.build(id=i))\n <MID> class InterfaceManagerFactory(AbstractManagerFactory):\n def __init__(',
'<PRE> <SUF> = 0 :=\nbegin\nsplit,\n{ intros h f,\n rw pi_1_etalisation at h,\n simp [h],\n refl\n},\n{ intro h,\n have := @quasi_adjoint C D P,\n simp [←pi_1_etalisation, this, h],\n refl\n}\nend\n <MID> /-- A quasi-prefunctoid is 1-connected iff all its etalisations are 1-connected. -/\ntheorem connected_iff_etalisation [C D : precategoroid] (P : quasi_prefunctoid C D) :\nπ₁ P = 0 ↔ ' '<PRE> <SUF> = 0 :=\nbegin\nsplit,\n{ intros h f,\n rw pi_1_etalisation at h,\n simp [h],\n refl\n},\n{ intro h,\n have := @quasi_adjoint C D P,\n simp [←pi_1_etalisation, this, h],\n refl\n}\nend\n <MID> /-- A quasi-prefunctoid is 1-connected iff all its etalisations are 1-connected. -/\ntheorem connected_iff_etalisation [C D : precategoroid] (P : quasi_prefunctoid C D) :\nπ₁ P = 0 ↔ '
] ]
EXPECTED_IDS = torch.tensor([[ 1, 32007, 822, 3349, 29918, 5464, 29918, 294, 18869, 29898,29879, 29901, 851, 29897, 1599, 851, 29901, 13, 1678, 9995, 29871, 32008, 13, 1678, 736, 1121, 13, 32009, 15941, 1661, 29899, 28599, 2687, 4890, 515, 263, 1347, 29889, 13, 13, 1678, 826, 3174, 29901, 13, 4706, 269, 29901, 450, 1347, 304, 3349, 1661, 29899, 28599, 2687, 4890, 515, 29889, 13, 13, 1678, 16969, 29901, 13, 4706, 450, 1347, 411, 1661, 29899, 28599, 2687, 4890, 6206, 29889, 13, 1678, 9995, 13, 1678, 1121, 353, 5124, 13, 1678, 363, 274, 297, 269, 29901, 13, 4706, 565, 4356, 29898, 29883, 29897, 529, 29871, 29896, 29906, 29947, 29901, 13, 9651, 1121, 4619, 274, 32010, 2]]) EXPECTED_IDS = torch.tensor([[1, 32007, 822, 3349, 29918, 5464, 29918, 294, 18869, 29898, 29879, 29901, 851, 29897, 1599, 851, 29901, 13, 1678, 9995, 29871, 32008, 13, 1678, 736, 1121, 13, 32009, 15941, 1661, 29899, 28599, 2687, 4890, 515, 263, 1347, 29889, 13, 13, 1678, 826, 3174, 29901, 13, 4706, 269, 29901, 450, 1347, 304, 3349, 1661, 29899, 28599, 2687, 4890, 515, 29889, 13, 13, 1678, 16969, 29901, 13, 4706, 450, 1347, 411, 1661, 29899, 28599, 2687, 4890, 6206, 29889, 13, 1678, 9995, 13, 1678, 1121, 353, 5124, 13, 1678, 363, 274, 297, 269, 29901, 13, 4706, 565, 4356, 29898, 29883, 29897, 529, 29871, 29896, 29906, 29947, 29901, 13, 9651, 1121, 4619, 274, 32010, 2]])
# fmt: on # fmt: on
self.assertEqual(processed_text_suffix_first, EXPECTED_TEXT) self.assertEqual(processed_text_suffix_first, EXPECTED_TEXT)
input_ids = tokenizer(self.PROMPTS[0], return_tensors="pt")["input_ids"] input_ids = tokenizer(self.PROMPTS[0], return_tensors="pt")["input_ids"]
@ -816,3 +817,253 @@ end
] ]
infilling = tokenizer.batch_decode(generated_ids) infilling = tokenizer.batch_decode(generated_ids)
self.assertEqual(infilling, EXPECTED_INFILLING) self.assertEqual(infilling, EXPECTED_INFILLING)
@slow
@require_torch_gpu
class Mask4DTestHard(unittest.TestCase):
    """Integration tests for custom 4D attention masks on a Llama checkpoint.

    Three prompts that share a common prefix are evaluated in two ways:
    as a regular batch of three rows, and as one "stacked" row in which the
    prefix appears once and a hand-built 4D mask keeps the three continuations
    from attending to each other. The token predicted after each continuation
    must be the same in both setups.

    The 4D mask is passed in inverted (additive) form: 0 where attention is
    allowed and ``torch.finfo(dtype).min`` where it is blocked.
    """

    def tearDown(self):
        """Drop Python garbage and release cached CUDA memory after each test."""
        gc.collect()
        torch.cuda.empty_cache()

    def setUp(self):
        """Load the tokenizer and model used by every test in this class."""
        model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        self.model_dtype = torch.float32
        self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
        self.model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)

    def get_test_data(self):
        """Build the regular-batch inputs and the equivalent stacked inputs.

        Returns:
            A 5-tuple ``(input_ids, position_ids, input_ids_shared_prefix,
            mask_shared_prefix, position_ids_shared_prefix)`` where the first two
            describe the regular 3-row batch and the last three describe the single
            stacked row, its inverted 4D mask and its custom position ids.
        """
        template = "my favorite {}"
        items = ("pet is a", "artist plays a", "name is L")  # same number of tokens in each item

        batch_separate = [template.format(x) for x in items]  # 3 separate lines
        batch_shared_prefix = template.format(" ".join(items))  # 1 line with options concatenated

        input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device)
        input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device)

        # 1 = may attend, 0 = blocked. Rows 0-5 are plain causal over the first six
        # tokens; rows 6-8 and rows 9-11 each see only the first three tokens plus
        # their own three-token group.
        mask_shared_prefix = torch.tensor(
            [
                [
                    [
                        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
                        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
                        [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
                        [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
                        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
                    ]
                ]
            ],
            device=torch_device,
        )

        position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device)

        # building custom positions ids based on custom mask
        position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1)
        # effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)

        # inverting the mask
        min_dtype = torch.finfo(self.model_dtype).min
        mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype

        return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix

    def _decode_regular_batch(self, input_ids, position_ids):
        """Return the decoded argmax prediction after the last token of each batch row."""
        logits = self.model.forward(input_ids, position_ids=position_ids).logits
        logits_last = logits[:, -1, :]  # last tokens in each batch line
        return [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]

    @staticmethod
    def _last_token_indices(position_ids_shared_prefix):
        """Return the column indices of the stacked row that hold the maximal position id."""
        return torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1]

    def _pad_mask_to_cache_len(self, mask, max_cache_len):
        """Right-pad the last dim of an inverted 4D mask up to ``max_cache_len``.

        The padded columns are filled with ``torch.finfo(dtype).min`` (blocked). In
        the inverted convention a fill value of 0 would mean "attend", which would
        let queries attend to cache slots that have not been written yet.
        """
        return torch.nn.functional.pad(
            input=mask,
            pad=(0, max_cache_len - mask.shape[-1]),
            mode="constant",
            value=torch.finfo(self.model_dtype).min,
        )

    def _new_static_cache(self, max_cache_len):
        """Create a single-sequence ``StaticCache`` for the model under test."""
        return StaticCache(
            config=self.model.config,
            max_batch_size=1,
            max_cache_len=max_cache_len,
            device=torch_device,
            dtype=self.model.dtype,
        )

    def test_stacked_causal_mask(self):
        """A single forward pass with the 4D mask must match the regular batch."""
        (
            input_ids,
            position_ids,
            input_ids_shared_prefix,
            mask_shared_prefix,
            position_ids_shared_prefix,
        ) = self.get_test_data()

        # regular batch
        decoded = self._decode_regular_batch(input_ids, position_ids)

        # single forward run with 4D custom mask
        logits_shared_prefix = self.model.forward(
            input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix
        ).logits
        logits_shared_prefix_last = logits_shared_prefix[
            0, self._last_token_indices(position_ids_shared_prefix), :
        ]  # last three tokens
        decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]

        self.assertEqual(decoded, decoded_shared_prefix)

    def test_partial_stacked_causal_mask(self):
        """Same as the stacked test, but the input is fed in two chunks.

        It tests that we can pass partial 4D attention masks together with the
        ``past_key_values`` returned by the first chunk.
        """
        (
            input_ids,
            position_ids,
            input_ids_shared_prefix,
            mask_shared_prefix,
            position_ids_shared_prefix,
        ) = self.get_test_data()

        # regular batch
        decoded = self._decode_regular_batch(input_ids, position_ids)

        # 2 forward runs with custom 4D masks
        part_a = 3  # split point

        input_1a = input_ids_shared_prefix[:, :part_a]
        position_ids_1a = position_ids_shared_prefix[:, :part_a]
        mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]

        outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a)
        past_key_values_a = outs_1a["past_key_values"]

        # Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len])
        input_1b = input_ids_shared_prefix[:, part_a:]
        position_ids_1b = position_ids_shared_prefix[:, part_a:]
        mask_1b = mask_shared_prefix[:, :, part_a:, :]
        outs_1b = self.model.forward(
            input_1b,
            attention_mask=mask_1b,
            position_ids=position_ids_1b,
            past_key_values=past_key_values_a,
        )
        # Indices are shifted by part_a because the second chunk starts at that offset.
        decoded_1b = [
            self.tokenizer.decode(t)
            for t in outs_1b.logits.argmax(-1)[0, self._last_token_indices(position_ids_shared_prefix) - part_a]
        ]
        self.assertEqual(decoded, decoded_1b)

    def test_stacked_causal_mask_static_cache(self):
        """Same as `test_stacked_causal_mask`, but with a StaticCache."""
        (
            input_ids,
            position_ids,
            input_ids_shared_prefix,
            mask_shared_prefix,
            position_ids_shared_prefix,
        ) = self.get_test_data()

        # regular batch
        decoded = self._decode_regular_batch(input_ids, position_ids)

        # upgrade the model with StaticCache
        max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
        past_key_values = self._new_static_cache(max_cache_len)

        padded_attention_mask = self._pad_mask_to_cache_len(mask_shared_prefix, max_cache_len)

        # single forward run with 4D custom mask
        logits_shared_prefix = self.model.forward(
            input_ids_shared_prefix,
            attention_mask=padded_attention_mask,
            position_ids=position_ids_shared_prefix,
            cache_position=torch.arange(input_ids_shared_prefix.shape[-1], device=torch_device),
            past_key_values=past_key_values,
        ).logits
        logits_shared_prefix_last = logits_shared_prefix[
            0, self._last_token_indices(position_ids_shared_prefix), :
        ]  # last three tokens
        decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]

        self.assertEqual(decoded, decoded_shared_prefix)

    def test_partial_stacked_causal_mask_static_cache(self):
        """Same as the StaticCache test, but the input is fed in two chunks.

        It tests that we can pass partial 4D attention masks shaped
        [..., seq_len, full_static_cache_len].
        """
        (
            input_ids,
            position_ids,
            input_ids_shared_prefix,
            mask_shared_prefix,
            position_ids_shared_prefix,
        ) = self.get_test_data()

        # regular batch
        decoded = self._decode_regular_batch(input_ids, position_ids)

        # upgrade the model with StaticCache
        max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
        past_key_values = self._new_static_cache(max_cache_len)

        # forward run for the first part of input
        part_a = 3  # split point

        input_1a = input_ids_shared_prefix[:, :part_a]
        position_ids_1a = position_ids_shared_prefix[:, :part_a]
        mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]

        padded_mask_1a = self._pad_mask_to_cache_len(mask_1a, max_cache_len)

        # The return value is not needed: the call fills the static cache in place.
        _ = self.model.forward(
            input_1a,
            attention_mask=padded_mask_1a,
            position_ids=position_ids_1a,
            cache_position=torch.arange(part_a, device=torch_device),
            past_key_values=past_key_values,
        )

        # forward run for the second part of input
        input_1b = input_ids_shared_prefix[:, part_a:]
        position_ids_1b = position_ids_shared_prefix[:, part_a:]
        mask_1b = mask_shared_prefix[:, :, part_a:, :]

        # Pad with the blocking value (not 0) so the unused tail of the static
        # cache stays masked, consistent with the first chunk.
        padded_mask_1b = self._pad_mask_to_cache_len(mask_1b, max_cache_len)

        outs_1b = self.model.forward(
            input_1b,
            attention_mask=padded_mask_1b,
            position_ids=position_ids_1b,
            cache_position=torch.arange(
                part_a,
                input_ids_shared_prefix.shape[-1],
                device=torch_device,
            ),
            past_key_values=past_key_values,
        )
        # Indices are shifted by part_a because the second chunk starts at that offset.
        decoded_1b = [
            self.tokenizer.decode(t)
            for t in outs_1b.logits.argmax(-1)[0, self._last_token_indices(position_ids_shared_prefix) - part_a]
        ]
        self.assertEqual(decoded, decoded_1b)

View File

@ -627,3 +627,127 @@ class MistralIntegrationTest(unittest.TestCase):
del model del model
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
gc.collect() gc.collect()
@slow
@require_torch_gpu
class Mask4DTestHard(unittest.TestCase):
    """Compare a regular Mistral batch against one stacked row driven by a custom 4D attention mask.

    Three prompts sharing the prefix "my favorite" are run once as a normal batch and once
    concatenated into a single sequence; the greedy next tokens must agree.
    """

    def tearDown(self):
        """Collect garbage and release cached CUDA memory after each test."""
        gc.collect()
        torch.cuda.empty_cache()

    def setUp(self):
        """Load the tokenizer and the float32 model, moving the model to `torch_device`."""
        checkpoint = "mistralai/Mistral-7B-v0.1"
        self.model_dtype = torch.float32
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=False)
        loaded_model = MistralForCausalLM.from_pretrained(checkpoint, torch_dtype=self.model_dtype)
        self.model = loaded_model.to(torch_device)

    def get_test_data(self):
        """Build the batched inputs and their stacked shared-prefix counterpart.

        Returns:
            A 5-tuple `(input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix,
            position_ids_shared_prefix)`. The mask is returned in inverted (additive) form:
            0 where attention is allowed, the minimum of `self.model_dtype` elsewhere.
        """
        template = "my favorite {}"
        items = ("pet is a", "artist plays a", "name is L")  # same number of tokens in each item
        separate_prompts = [template.format(item) for item in items]  # 3 separate lines
        stacked_prompt = template.format(" ".join(items))  # 1 line with options concatenated

        input_ids = self.tokenizer(separate_prompts, return_tensors="pt").input_ids.to(torch_device)
        input_ids_shared_prefix = self.tokenizer(stacked_prompt, return_tensors="pt").input_ids.to(torch_device)

        # 1 marks a key position (column) that the query position (row) may attend to.
        visibility_rows = [
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
            [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
            [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
        ]
        visibility = torch.tensor([[visibility_rows]], device=torch_device)

        num_rows, seq_len = input_ids.shape[0], input_ids.shape[1]
        position_ids = torch.arange(seq_len).tile(num_rows, 1).to(torch_device)

        # Position of each stacked token = number of visible tokens in its row, minus one.
        # effectively: torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
        position_ids_shared_prefix = (visibility.sum(dim=-1) - 1).reshape(1, -1)

        # Invert the mask: masked-out entries become the dtype minimum, visible ones become 0.
        min_dtype = torch.finfo(self.model_dtype).min
        mask_shared_prefix = visibility.eq(0.0).to(dtype=self.model_dtype) * min_dtype

        return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix

    def test_stacked_causal_mask(self):
        # One forward pass over the stacked row must predict the same tokens as the plain batch.
        (
            input_ids,
            position_ids,
            input_ids_shared_prefix,
            mask_shared_prefix,
            position_ids_shared_prefix,
        ) = self.get_test_data()

        # regular batch: greedy token after the last position of every line
        batch_logits = self.model.forward(input_ids, position_ids=position_ids).logits
        batch_choice = batch_logits[:, -1, :].argmax(dim=-1)
        decoded = [self.tokenizer.decode(token) for token in batch_choice]

        # single forward run with 4D custom mask
        stacked_logits = self.model.forward(
            input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix
        ).logits
        # columns holding the highest position id, i.e. the last token of each stacked item
        final_columns = torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1]
        stacked_choice = stacked_logits[0, final_columns, :].argmax(dim=-1)
        decoded_shared_prefix = [self.tokenizer.decode(token) for token in stacked_choice]

        self.assertEqual(decoded, decoded_shared_prefix)

    def test_partial_stacked_causal_mask(self):
        # Same as the test above, but the input is passed in two groups.
        # It tests that we can pass partial 4D attention masks.
        (
            input_ids,
            position_ids,
            input_ids_shared_prefix,
            mask_shared_prefix,
            position_ids_shared_prefix,
        ) = self.get_test_data()

        # regular batch
        batch_logits = self.model.forward(input_ids, position_ids=position_ids).logits
        batch_choice = batch_logits[:, -1, :].argmax(dim=-1)
        decoded = [self.tokenizer.decode(token) for token in batch_choice]

        # 2 forward runs with custom 4D masks
        part_a = 3  # split point

        # First run: the leading tokens only, with the matching top-left corner of the mask.
        outs_1a = self.model.forward(
            input_ids_shared_prefix[:, :part_a],
            attention_mask=mask_shared_prefix[:, :, :part_a, :part_a],
            position_ids=position_ids_shared_prefix[:, :part_a],
        )
        past_key_values_a = outs_1a["past_key_values"]

        # Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len])
        outs_1b = self.model.forward(
            input_ids_shared_prefix[:, part_a:],
            attention_mask=mask_shared_prefix[:, :, part_a:, :],
            position_ids=position_ids_shared_prefix[:, part_a:],
            past_key_values=past_key_values_a,
        )

        # Shift the "last token" columns by the split point to index into the second chunk.
        final_columns = torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
        chunk_choice = outs_1b.logits.argmax(-1)[0, final_columns]
        decoded_1b = [self.tokenizer.decode(token) for token in chunk_choice]

        self.assertEqual(decoded, decoded_1b)

View File

@ -4277,6 +4277,80 @@ class ModelTesterMixin:
self.assertFalse(fa2_correctly_converted) self.assertFalse(fa2_correctly_converted)
def _get_custom_4d_mask_test_data(self):
    """Build toy inputs for comparing a plain batch with a stacked, 4D-masked sequence.

    Returns:
        A 5-tuple `(input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix,
        position_ids_shared_prefix)`. The mask is float32 and inverted: 0 where attention
        is allowed and the float32 minimum where it is blocked.
    """
    int64_on_device = {"device": torch_device, "dtype": torch.int64}

    # Sequence in which all but the last token is the same
    input_ids = torch.tensor([[10, 11, 12, 13], [10, 11, 12, 14], [10, 11, 12, 15]], **int64_on_device)
    position_ids = torch.tensor([[0, 1, 2, 3]] * 3, **int64_on_device)

    # Combining common prefix with the unique ending tokens:
    common_prefix = input_ids[0][:-1]
    unique_endings = input_ids[:, -1]
    input_ids_shared_prefix = torch.cat([common_prefix, unique_endings]).unsqueeze(0)

    # Creating a 4D mask where each of the last 3 tokens do not attend to each other.
    visibility_rows = [
        [1, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 0],
        [1, 1, 1, 0, 1, 0],
        [1, 1, 1, 0, 0, 1],
    ]
    visibility = torch.tensor([[visibility_rows]])

    # inverting the attention mask
    mask_dtype = torch.float32
    blocked = visibility.eq(0.0).to(dtype=mask_dtype, device=torch_device)
    mask_shared_prefix = blocked * torch.finfo(mask_dtype).min

    # Creating a position_ids tensor. note the repeating figures in the end.
    position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 3, 3]], **int64_on_device)

    return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
def test_custom_4d_attention_mask(self):
    # Checks that, for every generative model class, running three sequences as a regular batch
    # and running them stacked into one row with a custom 4D attention mask give the same
    # predictions for the final token of each sequence.
    if len(self.all_generative_model_classes) == 0:
        self.skipTest("Model architecture has no generative classes, and thus not necessarily supporting 4D masks")

    for model_class in self.all_generative_model_classes:
        if not model_class._supports_cache_class:
            self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")

        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        model = model_class(config).to(device=torch_device, dtype=torch.float32)

        (
            input_ids,
            position_ids,
            input_ids_shared_prefix,
            mask_shared_prefix,
            position_ids_shared_prefix,
        ) = self._get_custom_4d_mask_test_data()

        logits = model.forward(input_ids, position_ids=position_ids).logits
        # logits.shape == torch.Size([3, 4, ...])

        logits_shared_prefix = model(
            input_ids_shared_prefix,
            attention_mask=mask_shared_prefix,
            position_ids=position_ids_shared_prefix,
        )[0]
        # logits_shared_prefix.shape == torch.Size([1, 6, ...])

        out_last_tokens = logits[:, -1, :]  # last tokens in each batch line
        out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :]  # last three tokens

        # comparing greedily-chosen tokens (unittest assertion, so the check survives `python -O`):
        self.assertTrue(
            torch.equal(out_last_tokens.max(axis=1).indices, out_shared_prefix_last_tokens.max(axis=1).indices)
        )

        # comparing softmax-normalized logits; `dim` is passed explicitly because relying on the
        # implicit dimension is deprecated (both tensors are 2D, so the vocabulary axis is the last one):
        normalized_0 = F.softmax(out_last_tokens, dim=-1)
        normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1)
        torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
global_rng = random.Random() global_rng = random.Random()

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy import copy
import gc
import glob import glob
import json import json
import os import os
@ -53,7 +52,6 @@ from transformers.testing_utils import (
require_tf, require_tf,
require_torch, require_torch,
require_torch_accelerator, require_torch_accelerator,
require_torch_gpu,
require_torch_multi_accelerator, require_torch_multi_accelerator,
require_usr_bin_time, require_usr_bin_time,
slow, slow,
@ -2107,229 +2105,6 @@ class TestAttentionImplementation(unittest.TestCase):
self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception)) self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))
@require_torch_gpu
class Mask4DTestBase(unittest.TestCase):
    """Shared fixture for the 4D attention mask tests.

    Subclasses are expected to set `self.tokenizer` (and the model) in `setUp`; this base
    only supplies the test data and the per-test CUDA cleanup.
    """

    def tearDown(self):
        """Collect garbage and release cached CUDA memory after each test."""
        gc.collect()
        torch.cuda.empty_cache()

    def get_test_data(self):
        """Build a 3-row batch and the equivalent single stacked row with its 4D mask.

        Returns:
            A 5-tuple `(input_0, position_ids_0, input_1, mask_1, position_ids_1)` where the
            `_0` tensors describe the regular batch and the `_1` tensors describe the stacked
            row. `mask_1` is an int64 tensor of 0/1 values (1 = may attend).
        """
        texts = ["the cat sat", "the cat had", "the cat is"]
        encoded = [self.tokenizer.encode(t) for t in texts]
        input_0 = torch.tensor(encoded, device=torch_device)
        # tensor([[   1,  278, 6635, 3290],
        #         [   1,  278, 6635,  750],
        #         [   1,  278, 6635,  338]], device='cuda:0')

        position_ids_0 = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64)

        # Combining common prefix with the unique ending tokens:
        input_1 = torch.cat([input_0[0][:-1], input_0[:, -1]]).unsqueeze(0)
        # tensor([[   1,  278, 6635, 3290,  750,  338]], device='cuda:0')

        # Creating a 4D mask where each of the last 3 tokens do not attend to each other.
        # Built on `torch_device` (not a hard-coded "cuda:0") so it always lives on the same
        # device as the inputs and the model.
        mask_1 = torch.tensor(
            [
                [
                    [
                        [1, 0, 0, 0, 0, 0],
                        [1, 1, 0, 0, 0, 0],
                        [1, 1, 1, 0, 0, 0],
                        [1, 1, 1, 1, 0, 0],
                        [1, 1, 1, 0, 1, 0],
                        [1, 1, 1, 0, 0, 1],
                    ]
                ]
            ],
            device=torch_device,
            dtype=torch.int64,
        )

        # Creating a position_ids tensor. note the repeating figures in the end.
        position_ids_1 = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)

        return input_0, position_ids_0, input_1, mask_1, position_ids_1
@require_torch_gpu
class Mask4DTestFP32(Mask4DTestBase):
    """Float32 4D-mask checks on a small Llama-like checkpoint."""

    def setUp(self):
        """Load the tokenizer and the float32 model, moving the model to `torch_device`."""
        checkpoint = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
        self.model_dtype = torch.float32
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        loaded_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=self.model_dtype)
        self.model = loaded_model.to(torch_device)

    def test_attention(self):
        """comparing outputs of attention layer"""
        # Batch: one row per sentence; stacked: same data in a single row with custom attention
        batch_ids, batch_positions, stacked_ids, stacked_mask, stacked_positions = self.get_test_data()

        # Turn the 0/1 mask into an additive one: 0 where visible, dtype minimum where blocked.
        lowest = torch.finfo(self.model_dtype).min
        additive_mask = (1 - stacked_mask).to(self.model_dtype) * lowest

        inner_model = self.model.model
        first_attention = inner_model.layers[0].self_attn

        batch_hidden = inner_model.embed_tokens(batch_ids)
        batch_out = first_attention.forward(batch_hidden, position_ids=batch_positions)[0]
        # batch_out.shape == torch.Size([3, 4, 768])

        stacked_hidden = inner_model.embed_tokens(stacked_ids)
        stacked_out = first_attention.forward(
            stacked_hidden, attention_mask=additive_mask, position_ids=stacked_positions
        )[0]
        # stacked_out.shape == torch.Size([1, 6, 768])

        batch_final = batch_out[:, -1, :]  # last tokens in each batch line
        stacked_final = stacked_out[0, -3:, :]  # last three tokens
        torch.testing.assert_close(batch_final, stacked_final)

    def test_causal_model_logits(self):
        """comparing logits outputs of whole inner model"""
        # Batch: one row per sentence; stacked: same data in a single row with custom attention
        batch_ids, batch_positions, stacked_ids, stacked_mask, stacked_positions = self.get_test_data()

        batch_logits = self.model.forward(batch_ids, position_ids=batch_positions).logits
        stacked_logits = self.model.forward(
            stacked_ids, attention_mask=stacked_mask.bool(), position_ids=stacked_positions
        ).logits

        batch_final = batch_logits[:, -1, :]  # last tokens in each batch line
        stacked_final = stacked_logits[0, -3:, :]  # last three tokens
        torch.testing.assert_close(batch_final, stacked_final)
@require_torch_gpu
class Mask4DTestFP16(Mask4DTestBase):
    """Float16 4D-mask checks; reuses the FP32 attention test and relaxes logits tolerances."""

    # The attention-layer comparison is identical to the FP32 one, only the loaded dtype differs.
    test_attention = Mask4DTestFP32.test_attention

    def setUp(self):
        """Load the tokenizer and the float16 model, moving the model to `torch_device`."""
        checkpoint = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
        self.model_dtype = torch.float16
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        loaded_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=self.model_dtype)
        self.model = loaded_model.to(torch_device)

    def test_causal_model_logits(self):
        """comparing logits outputs of whole inner model"""
        # Batch: one row per sentence; stacked: same data in a single row with custom attention
        batch_ids, batch_positions, stacked_ids, stacked_mask, stacked_positions = self.get_test_data()

        batch_logits = self.model.forward(batch_ids, position_ids=batch_positions).logits
        stacked_logits = self.model.forward(
            stacked_ids, attention_mask=stacked_mask.bool(), position_ids=stacked_positions
        ).logits

        batch_final = batch_logits[:, -1, :]  # last tokens in each batch line
        stacked_final = stacked_logits[0, -3:, :]  # last three tokens

        batch_ranking = batch_final.sort(descending=True).indices
        stacked_ranking = stacked_final.sort(descending=True).indices

        # checking logits, but note relaxed tolerances for FP16
        torch.testing.assert_close(batch_final, stacked_final, atol=0.02, rtol=0.001)

        # checking tokens order for the top tokens
        top_k = 128
        for batch_order, stacked_order in zip(batch_ranking, stacked_ranking):
            self.assertTrue(torch.equal(batch_order[:top_k], stacked_order[:top_k]))
@slow
@require_torch_gpu
class Mask4DTestHard(unittest.TestCase):
    """Compare a regular TinyLlama batch against one stacked row driven by a custom 4D attention mask.

    Three prompts sharing the prefix "my favorite" are run once as a normal batch and once
    concatenated into a single sequence; the greedy next tokens must agree.
    """

    def tearDown(self):
        """Collect garbage and release cached CUDA memory after each test."""
        gc.collect()
        torch.cuda.empty_cache()

    def setUp(self):
        """Load the tokenizer and the float32 model, moving the model to `torch_device`."""
        checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        self.model_dtype = torch.float32
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        loaded_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=self.model_dtype)
        self.model = loaded_model.to(torch_device)

    def get_test_data(self):
        """Build the batched inputs and their stacked counterpart.

        Returns:
            A 5-tuple `(input_0, position_ids_0, input_1, mask_1, position_ids_1)` where the
            `_0` tensors describe the regular batch and the `_1` tensors describe the stacked
            row. `mask_1` is an int64 tensor of 0/1 values (1 = may attend).
        """
        template = "my favorite {}"
        items = ("pet is a", "artist plays a", "name is L")  # same number of tokens in each item
        separate_prompts = [template.format(item) for item in items]  # 3 separate lines
        stacked_prompt = template.format(" ".join(items))  # 1 line with options concatenated

        input_0 = self.tokenizer(separate_prompts, return_tensors="pt").input_ids.to(torch_device)
        input_1 = self.tokenizer(stacked_prompt, return_tensors="pt").input_ids.to(torch_device)

        # 1 marks a key position (column) that the query position (row) may attend to.
        visibility_rows = [
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
            [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
            [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
        ]
        mask_1 = torch.tensor([[visibility_rows]], device=torch_device, dtype=torch.int64)

        num_rows, seq_len = input_0.shape[0], input_0.shape[1]
        position_ids_0 = torch.arange(seq_len).tile(num_rows, 1).to(torch_device)

        # Position of each stacked token = number of visible tokens in its row, minus one.
        # equivalent: position_ids_1 = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
        position_ids_1 = (mask_1.sum(dim=-1) - 1).reshape(1, -1)

        return input_0, position_ids_0, input_1, mask_1, position_ids_1

    def test_stacked_causal_mask(self):
        # Batch: one row per sentence; stacked: same data in a single row with custom attention
        input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()

        # regular batch: greedy token after the last position of every line
        batch_logits = self.model.forward(input_0, position_ids=position_ids_0).logits
        batch_choice = batch_logits[:, -1, :].argmax(dim=-1)
        decoded_0 = [self.tokenizer.decode(token) for token in batch_choice]

        # single forward run with 4D custom mask
        stacked_logits = self.model.forward(
            input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1
        ).logits
        # columns holding the highest position id, i.e. the last token of each stacked item
        final_columns = torch.where(position_ids_1 == position_ids_1.max())[1]
        stacked_choice = stacked_logits[0, final_columns, :].argmax(dim=-1)
        decoded_1 = [self.tokenizer.decode(token) for token in stacked_choice]

        self.assertEqual(decoded_0, decoded_1)

    def test_partial_stacked_causal_mask(self):
        # Same as the test above, but the input is passed in two groups. It tests that we can pass
        # partial 4D attention masks.
        # Batch: one row per sentence; stacked: same data in a single row with custom attention
        input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()

        # regular batch
        batch_logits = self.model.forward(input_0, position_ids=position_ids_0).logits
        batch_choice = batch_logits[:, -1, :].argmax(dim=-1)
        decoded_0 = [self.tokenizer.decode(token) for token in batch_choice]

        # 2 forward runs with custom 4D masks
        part_a = 3  # split point

        # First run: the leading tokens only, with the matching top-left corner of the mask.
        outs_1a = self.model.forward(
            input_1[:, :part_a],
            attention_mask=mask_1[:, :, :part_a, :part_a].bool(),
            position_ids=position_ids_1[:, :part_a],
        )
        past_key_values_a = outs_1a["past_key_values"]

        # Second run: the remaining tokens, attending over the full key length via the cache.
        outs_1b = self.model.forward(
            input_1[:, part_a:],
            attention_mask=mask_1[:, :, part_a:, :].bool(),
            position_ids=position_ids_1[:, part_a:],
            past_key_values=past_key_values_a,
        )

        # Shift the "last token" columns by the split point to index into the second chunk.
        final_columns = torch.where(position_ids_1 == position_ids_1.max())[1] - part_a
        chunk_choice = outs_1b.logits.argmax(-1)[0, final_columns]
        decoded_1b = [self.tokenizer.decode(token) for token in chunk_choice]

        self.assertEqual(decoded_0, decoded_1b)
@require_torch @require_torch
class TestTensorSharing(TestCasePlus): class TestTensorSharing(TestCasePlus):
def test_disjoint(self): def test_disjoint(self):