Static Cache: load models with MQA or GQA (#28975)

Joao Gante 2024-02-13 09:58:19 +00:00 committed by GitHub
parent da20209dbc
commit 3e70a207df
2 changed files with 49 additions and 3 deletions

src/transformers/cache_utils.py

@@ -351,10 +351,12 @@ class StaticCache(Cache):
         self.max_batch_size = max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
         self.head_dim = config.hidden_size // config.num_attention_heads
-        self.num_heads = config.num_attention_heads
+        self.num_key_value_heads = (
+            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
+        )
         self.dtype = config.torch_dtype if config.torch_dtype is not None else dtype

-        cache_shape = (max_batch_size, self.num_heads, self.max_cache_len, self.head_dim)
+        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
         self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
         self.seen_tokens = 0
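With this change the static cache allocates one slot per key/value head instead of one per attention head. That is enough because GQA/MQA attention layers expand the cached key/value heads back to the full query-head count at attention time. A minimal illustrative sketch of that expansion, modelled on the repeat_kv helper used in the Llama modeling code (not part of this diff):

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (batch, num_key_value_heads, seq_len, head_dim) to
    # (batch, num_key_value_heads * n_rep, seq_len, head_dim) so cached GQA/MQA
    # keys/values line up with the num_attention_heads query heads.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

# e.g. a GQA cache with 4 KV heads serving 32 query heads (n_rep = 32 // 4 = 8)
cached_keys = torch.zeros(1, 4, 10, 128)
print(repeat_kv(cached_keys, 8).shape)  # torch.Size([1, 32, 10, 128])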

tests/test_cache_utils.py

@@ -35,14 +35,16 @@ if is_torch_available():
         AutoModelForCausalLM,
         AutoTokenizer,
         DynamicCache,
+        LlamaConfig,
         LlamaForCausalLM,
         SinkCache,
+        StaticCache,
     )


 @require_torch
 class CacheTest(unittest.TestCase):
-    def test_cache_equivalence(self):
+    def test_dynamic_cache_retrocompatibility(self):
         """Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
         legacy_cache = ()
         new_cache = DynamicCache()
@@ -120,6 +122,48 @@ class CacheTest(unittest.TestCase):
             )
         )

+    def test_static_cache_mha_mqa_gqa(self):
+        """
+        Tests that static cache works with multi-head attention (MHA), grouped query attention (GQA), and multi-query
+        attention (MQA)
+        """
+
+        def _random_kvs(config):
+            # shape for key and values: (batch_size, num_heads, seq_len, head_dim)
+            random_keys = torch.rand(
+                (1, config.num_key_value_heads, 1, config.hidden_size // config.num_attention_heads),
+                device=torch_device,
+            )
+            random_values = torch.rand(
+                (1, config.num_key_value_heads, 1, config.hidden_size // config.num_attention_heads),
+                device=torch_device,
+            )
+            return random_keys, random_values
+
+        mha_config = LlamaConfig(num_attention_heads=32)
+        mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
+        cached_keys, cached_values = mha_static_cache.update(
+            *_random_kvs(mha_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
+        )
+        self.assertTrue(cached_keys.shape == (1, 32, 10, 128))
+        self.assertTrue(cached_values.shape == (1, 32, 10, 128))
+
+        gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
+        gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
+        cached_keys, cached_values = gqa_static_cache.update(
+            *_random_kvs(gqa_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
+        )
+        self.assertTrue(cached_keys.shape == (1, 4, 10, 128))
+        self.assertTrue(cached_values.shape == (1, 4, 10, 128))
+
+        mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
+        mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
+        cached_keys, cached_values = mqa_static_cache.update(
+            *_random_kvs(mqa_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
+        )
+        self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
+        self.assertTrue(cached_values.shape == (1, 1, 10, 128))


 @require_torch_gpu
 @slow
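For reference, the 128 in the new test's shape assertions follows from LlamaConfig's defaults (hidden_size=4096, num_attention_heads=32, so head_dim = 4096 // 32 = 128). A small standalone sketch of the same shape arithmetic, assuming only those LlamaConfig defaults and the StaticCache constructor shown in this diff (illustrative, not part of the commit):

from transformers import LlamaConfig, StaticCache

# MHA: kv heads default to the query-head count; GQA: 4 kv heads; MQA: 1 kv head.
for name, kv_heads in [("mha", None), ("gqa", 4), ("mqa", 1)]:
    config = LlamaConfig(num_attention_heads=32, num_key_value_heads=kv_heads)
    cache = StaticCache(config=config, max_batch_size=1, max_cache_len=10, device="cpu")
    # head_dim = hidden_size // num_attention_heads = 4096 // 32 = 128
    print(name, tuple(cache.key_cache.shape))
# mha (1, 32, 10, 128)
# gqa (1, 4, 10, 128)
# mqa (1, 1, 10, 128)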