Static Cache: load models with MQA or GQA (#28975)
This commit is contained in:
parent
da20209dbc
commit
3e70a207df
|
@ -351,10 +351,12 @@ class StaticCache(Cache):
|
||||||
self.max_batch_size = max_batch_size
|
self.max_batch_size = max_batch_size
|
||||||
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
||||||
self.head_dim = config.hidden_size // config.num_attention_heads
|
self.head_dim = config.hidden_size // config.num_attention_heads
|
||||||
self.num_heads = config.num_attention_heads
|
self.num_key_value_heads = (
|
||||||
|
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
||||||
|
)
|
||||||
self.dtype = config.torch_dtype if config.torch_dtype is not None else dtype
|
self.dtype = config.torch_dtype if config.torch_dtype is not None else dtype
|
||||||
|
|
||||||
cache_shape = (max_batch_size, self.num_heads, self.max_cache_len, self.head_dim)
|
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
|
||||||
self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
||||||
self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
||||||
self.seen_tokens = 0
|
self.seen_tokens = 0
|
||||||
|
|
|
@ -35,14 +35,16 @@ if is_torch_available():
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
DynamicCache,
|
DynamicCache,
|
||||||
|
LlamaConfig,
|
||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
SinkCache,
|
SinkCache,
|
||||||
|
StaticCache,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class CacheTest(unittest.TestCase):
|
class CacheTest(unittest.TestCase):
|
||||||
def test_cache_equivalence(self):
|
def test_dynamic_cache_retrocompatibility(self):
|
||||||
"""Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
|
"""Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
|
||||||
legacy_cache = ()
|
legacy_cache = ()
|
||||||
new_cache = DynamicCache()
|
new_cache = DynamicCache()
|
||||||
|
@ -120,6 +122,48 @@ class CacheTest(unittest.TestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_static_cache_mha_mqa_gqa(self):
|
||||||
|
"""
|
||||||
|
Tests that static cache works with multi-head attention (MHA), grouped query attention (GQA), and multi-query
|
||||||
|
attention (MQA)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _random_kvs(config):
|
||||||
|
# shape for key and values: (batch_size, num_heads, seq_len, head_dim)
|
||||||
|
random_keys = torch.rand(
|
||||||
|
(1, config.num_key_value_heads, 1, config.hidden_size // config.num_attention_heads),
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
random_values = torch.rand(
|
||||||
|
(1, config.num_key_value_heads, 1, config.hidden_size // config.num_attention_heads),
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
return random_keys, random_values
|
||||||
|
|
||||||
|
mha_config = LlamaConfig(num_attention_heads=32)
|
||||||
|
mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
||||||
|
cached_keys, cached_values = mha_static_cache.update(
|
||||||
|
*_random_kvs(mha_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
|
||||||
|
)
|
||||||
|
self.assertTrue(cached_keys.shape == (1, 32, 10, 128))
|
||||||
|
self.assertTrue(cached_values.shape == (1, 32, 10, 128))
|
||||||
|
|
||||||
|
gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
|
||||||
|
gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
||||||
|
cached_keys, cached_values = gqa_static_cache.update(
|
||||||
|
*_random_kvs(gqa_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
|
||||||
|
)
|
||||||
|
self.assertTrue(cached_keys.shape == (1, 4, 10, 128))
|
||||||
|
self.assertTrue(cached_values.shape == (1, 4, 10, 128))
|
||||||
|
|
||||||
|
mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
|
||||||
|
mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
||||||
|
cached_keys, cached_values = mqa_static_cache.update(
|
||||||
|
*_random_kvs(mqa_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
|
||||||
|
)
|
||||||
|
self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
|
||||||
|
self.assertTrue(cached_values.shape == (1, 1, 10, 128))
|
||||||
|
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@slow
|
@slow
|
||||||
|
|
Loading…
Reference in New Issue