---
date: '2024-10-10'
description: '[[thoughts/Compression|compression]] of key-value in [[thoughts/Transformers|Transformers]] model'
id: KV compression
modified: 2026-06-06 23:33:08 GMT-04:00
seealso:
  - '[[@li2024snapkvllmknowslooking]]'
  - '[[@ge2024modeltellsdiscardadaptive]]'
  - '[[@xiao2024efficientstreaminglanguagemodels]]'
  - '[[@cai2025pyramidkvdynamickvcache]]'
socials:
  lists: https://github.com/October2001/Awesome-KV-Cache-Compression
tags:
  - ml
  - llm
title: KV compression
created: '2024-10-10'
published: '2024-10-10'
pageLayout: default
noCite:
  - '@li2024snapkvllmknowslooking'
  - '@ge2024modeltellsdiscardadaptive'
  - '@xiao2024efficientstreaminglanguagemodels'
  - '@cai2025pyramidkvdynamickvcache'
slug: thoughts/KV-compression
permalink: https://aarnphm.xyz/thoughts/KV-compression.md
generator:
  quartz: v4.6.0
  hostedProvider: Cloudflare
  baseUrl: aarnphm.xyz
full: https://aarnphm.xyz/llms-full.txt
---
TLDR: Most algorithms determine importance by aggregating attention over observed queries \[@zhang2023h2oheavyhitteroracleefficient; @liu2023scissorhandsexploitingpersistenceimportance\]

More recent work aggregates attention from _limited observation windows_ \[@li2024snapkvllmknowslooking; @cai2025pyramidkvdynamickvcache\]

These methods use top\_k to find the $k$ indices of attention per head to preserve, and evict the less important ones.

Another technique is to offload the KV cache to central storage, so that it can be reused in other contexts.

## KV-cache layout

NHD/HND

- `NHD` <span>&rarr;</span> `(seq_len, num_heads, head_dim)`
- `HND` <span>&rarr;</span> `(num_heads, seq_len, head_dim)`
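
A minimal sketch of the two layouts, using NumPy instead of torch so that it runs standalone. The sizes are arbitrary and only there for illustration.

```python
import numpy as np

seq_len, num_heads, head_dim = 4, 2, 3

# NHD: tokens are the outermost axis, so all heads of one token sit next to each other.
kv_nhd = np.arange(seq_len * num_heads * head_dim, dtype=np.float32).reshape(
    seq_len, num_heads, head_dim
)

# HND: heads are the outermost axis, so all tokens of one head sit next to each other.
# transpose only permutes strides; ascontiguousarray materialises the new memory layout.
kv_hnd = np.ascontiguousarray(kv_nhd.transpose(1, 0, 2))

# The same (token, head) vector is addressed with swapped leading indices.
token, head = 3, 1
assert np.array_equal(kv_nhd[token, head], kv_hnd[head, token])
```

Which layout is faster depends on the attention kernel reading the cache, so treat the choice as a kernel-level detail rather than a property of the model.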

## idea

Look at past attention weights for each pair of key and value vectors
(a measure of the degree to which that KV’s representation has been queried during past attention operations)

Then select the KV with the least attention to evict

Think of the LFU (least frequently used) cache management policy
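
A rough sketch of this LFU-like policy for a single head, in NumPy. `evict_least_attended` is a hypothetical helper, and summing the attention over past queries is one possible aggregate. It also ignores causal masking, which biases real accumulated scores towards early positions.

```python
import numpy as np

def evict_least_attended(attn, num_evict):
    """attn: (num_queries, num_kv) attention weights from past queries, one head.

    Returns the sorted indices of the KVs that stay in the cache.
    """
    # Accumulated attention per KV plays the role of the LFU "use count".
    score = attn.sum(axis=0)
    num_keep = attn.shape[1] - num_evict
    # Stable sort on the negated score: highest first, ties keep the earlier position.
    order = np.argsort(-score, kind="stable")
    return np.sort(order[:num_keep])

attn = np.array([
    [0.7, 0.10, 0.1, 0.10],
    [0.5, 0.05, 0.4, 0.05],
    [0.6, 0.10, 0.2, 0.10],
])
keep = evict_least_attended(attn, num_evict=2)  # keeps positions 0 and 2
```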

The KV cache for each sequence in a particular layer is allocated on the GPU as a _# attention heads $\times$ sequence length_ tensor.

> \[!tip\] scaling
>
> total memory allocation scales with the _maximum_ sequence length for all attention heads of the KV cache
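
To make the scaling concrete, a back-of-the-envelope calculation. The shape (32 layers, 32 KV heads, head dimension 128, fp16) is an illustrative assumption, roughly a 7B dense model without GQA, and `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, max_seq_len, dtype_bytes=2):
    """Bytes needed to hold keys and values for one sequence."""
    # The leading factor 2 accounts for storing both K and V.
    return 2 * num_layers * num_kv_heads * head_dim * max_seq_len * dtype_bytes

per_token = kv_cache_bytes(32, 32, 128, max_seq_len=1)   # 524288 bytes = 0.5 MiB
full = kv_cache_bytes(32, 32, 128, max_seq_len=4096)     # 2147483648 bytes = 2 GiB
```

Every head is sized for the same maximum length, so evicting KVs from only some heads does not reduce the allocation unless the layout allows per-head lengths.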

## Adaptive KV-cache compression

[Model Tells You What to Discard: Adaptive KV Cache Compression for LLMs](https://arxiv.org/abs/2310.01801) \[@ge2024modeltellsdiscardadaptive\]

## Streaming LLM

[Efficient Streaming Language Models with Attention Sinks](https://arxiv.org/abs/2309.17453) \[@xiao2024efficientstreaminglanguagemodels\]

_Using attention sink_

Ablate attention among layers that are deemed less valuable to the current generation.
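
A minimal sketch of the attention-sink cache policy, assuming the cache keeps a few initial "sink" tokens plus a sliding window of the most recent tokens. `streaming_keep_indices` and its default sizes are illustrative, not taken from the reference implementation.

```python
def streaming_keep_indices(seq_len, num_sink=4, window=8):
    """Positions retained in the KV cache: the first `num_sink` tokens
    (the attention sinks) plus the most recent `window` tokens."""
    if seq_len <= num_sink + window:
        # Nothing to evict yet.
        return list(range(seq_len))
    return list(range(num_sink)) + list(range(seq_len - window, seq_len))

kept = streaming_keep_indices(20)
# kept == [0, 1, 2, 3, 12, 13, 14, 15, 16, 17, 18, 19]
```

The cache size stays bounded at `num_sink + window` regardless of how long the stream runs.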

## RocketKV

## Pyramid-KV

[PyramidKV: Dynamic KV Cache Compression based on Pyramidal Information Funneling](https://arxiv.org/abs/2406.02069) \[@cai2025pyramidkvdynamickvcache\]

![[thoughts/images/pyramid-kv.webp]]

## Snap-KV

[implementation](https://github.com/FasterDecoding/SnapKV)

[SnapKV: LLM Knows What You are Looking for Before Generation](https://arxiv.org/abs/2404.14469) \[@li2024snapkvllmknowslooking\]

Voting: calculate attention weights for each query within the observation window across all attention heads, then aggregate them to highlight important prefix positions. Formally, for a single batch:

$$
\begin{aligned}
C &= \sum_{i=0}^{L_{\text{obs}}} W_{\text{obs}}[:, i, :] \\
I &= \text{Top}_{k}(C, k)
\end{aligned}
$$

```python title="FasterDecoding/SnapKV · llama\_hijack\_4\_37.py:19" showLineNumbers \{19\}
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
import warnings
from transformers.cache_utils import Cache, DynamicCache
from transformers.models.llama.modeling_llama import (
    apply_rotary_pos_emb,
    repeat_kv,
)
from transformers.utils import (
    logging,
)
from snapkv.monkeypatch.snapkv_utils import init_snapkv

logger = logging.get_logger(__name__)

# https://github.com/huggingface/transformers/blob/v4.37-release/src/transformers/models/llama/modeling_llama.py
def llama_flash_attn2_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    # [SnapKV] register kv_cluster
    init_snapkv(self)
    # LlamaFlashAttention2 attention does not support output_attentions
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )

        # overwrite attention_mask with padding_mask
        attention_mask = kwargs.pop("padding_mask")

    output_attentions = False

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Flash attention requires the input to have the shape
    # batch_size x seq_length x head_dim x hidden_dim
    # therefore we just need to keep the original shape
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    
    kv_seq_len = key_states.shape[-2]
    # if past_key_value is not None:
    #     kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        if hasattr(self, "kv_seq_len"): #[SnapKV] add kv_seq_len
            if self.kv_seq_len != 0:
                kv_seq_len += self.kv_seq_len
            else:
                kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        else:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
    # [SnapKV] move to ahead
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        # print('kv_seq_len:', kv_seq_len)
        # print('key_states.shape:', key_states.shape)
        if key_states.shape[-2] == kv_seq_len: # [SnapKV] add kv_cluster
            self.kv_seq_len = kv_seq_len # [SnapKV] register kv_seq_len
            key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
            past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
        else:
            self.kv_seq_len += q_len
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
    # to be able to avoid many of these transpose/reshape/view.
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0

    # In PEFT, usually we cast the layer norms in float32 for training stability reasons
    # therefore the input hidden states gets silently casted in float32. Hence, we need
    # cast them back in the correct dtype just to be sure everything works as expected.
    # This might slowdown training & inference so it is recommended to not cast the LayerNorms
    # in fp32. (LlamaRMSNorm handles it correctly)

    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once(
            f"The input hidden states seems to be silently casted in float32, this might be related to"
            f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
            f" {target_dtype}."
        )

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    attn_output = self._flash_attention_forward(
        query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
    )

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value

def prepare_inputs_for_generation_llama(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    if past_key_values is None: # [SnapKV]
        for layer in self.model.layers:
            layer.self_attn.kv_seq_len = 0
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
        else:
            # cache_length = past_length = past_key_values[0][0].shape[2]
            # max_cache_length = None
            cache_length = past_length = self.model.layers[0].self_attn.kv_seq_len
            max_cache_length = None
        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
        # input)
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
        # input_ids based on the past_length.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
        if (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        ):
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs
```

> \[!tip\] Important
>
> $k$ is defined as $\lfloor p \times L_{\text{prefix}} \rfloor$, where $p$ is the compression rate.

Hit rate: essentially, attention features above a predefined threshold $\Theta$ are treated as <mark>important</mark> features.

The idea is to have two stages:

- **Vote for important features**: select the important prefix features based on the attention they receive from a fixed observation window.

- **Update and store the compressed KV**: concatenate the selected features with the features inside the window, then update the KV-cache.

- clustering via pooling <span>&rArr;</span> frequent hit-rate attention
  ```python
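  import torch.nn.functional as F

  # SnapKV pools with either average or max pooling; average pooling shown
  attn_cache = F.avg_pool1d(
    attn_weights_sum, kernel_size=kernel_size, padding=kernel_size // 2, stride=1
  )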
  ```
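
The voting, pooling and top-$k$ steps can be sketched end to end in NumPy. This is a simplified reading of the method, not the repository's code: `snapkv_select` is a hypothetical helper, it handles a single batch element, and it assumes an odd `kernel_size` with zero padding and average pooling.

```python
import numpy as np

def snapkv_select(w_obs, p, kernel_size=3):
    """w_obs: (num_heads, L_obs, L_prefix) attention from observation-window
    queries to prefix keys. Returns (num_heads, k) sorted prefix indices to keep."""
    l_prefix = w_obs.shape[-1]
    k = int(np.floor(p * l_prefix))  # k = floor(p * L_prefix)
    # Voting: C = sum_i W_obs[:, i, :]
    c = w_obs.sum(axis=1)
    # Clustering: stride-1 average pooling with zero padding of kernel_size // 2.
    pad = kernel_size // 2
    padded = np.pad(c, ((0, 0), (pad, pad)))
    windows = np.lib.stride_tricks.sliding_window_view(padded, kernel_size, axis=1)
    pooled = windows.mean(axis=-1)
    # I = Top_k(C, k), taken per head on the pooled scores.
    idx = np.argsort(-pooled, axis=-1, kind="stable")[:, :k]
    return np.sort(idx, axis=-1)

w_obs = np.array([[
    [0.0, 0.0, 0.9, 0.0, 0.0, 0.1],
    [0.0, 0.0, 0.8, 0.1, 0.0, 0.1],
]])
selected = snapkv_select(w_obs, p=0.5)  # one head, k = 3
```

In this toy input the raw votes would pick positions 2, 3 and 5. After pooling, the neighbours of the peak at position 2 win instead, which is the clustering effect. The compressed cache would then hold the selected prefix KVs plus the KVs of the observation window itself.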

## Ada-KV

idea: instead of a uniform eviction budget across heads, allocate a budget $B_i$ to each attention head, so that the number of KVs evicted varies per head

_built on-top of PyramidKV and SnapKV_

![[thoughts/images/vllm/ada-kv.webp]]

> \[!note\] Note
>
> With Ada-SnapKV, each attention layer is still assigned a fixed compression rate (refer to the image example)

[Ada-KV: Optimizing KV Cache Eviction by Adaptive Budget Allocation for Efficient LLM Inference](https://arxiv.org/abs/2407.11550) \[@feng2025adakvoptimizingkvcache\]
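
One way to picture the per-head budget $B_i$, as a sketch under my own assumptions: take a top-k over the scores of all heads in a layer at once, and let each head's budget be the number of winners it contributed. `adaptive_head_budgets` is a hypothetical helper, and the scores stand in for whatever the base method (SnapKV, PyramidKV) produces.

```python
import numpy as np

def adaptive_head_budgets(scores, avg_budget):
    """scores: (num_heads, seq_len) importance score per KV and head.

    Returns B_i per head; the budgets sum to num_heads * avg_budget,
    so the layer-level compression rate stays fixed.
    """
    num_heads, seq_len = scores.shape
    total = num_heads * avg_budget
    # Rank every (head, position) pair in the layer together.
    top = np.argsort(-scores.reshape(-1), kind="stable")[:total]
    # Recover which head each winning entry belongs to and count per head.
    return np.bincount(top // seq_len, minlength=num_heads)

scores = np.array([
    [0.4, 0.3, 0.2, 0.1],   # attention spread over several KVs
    [0.7, 0.1, 0.1, 0.1],   # attention concentrated on one KV
])
budgets = adaptive_head_budgets(scores, avg_budget=2)  # head 0 gets 3, head 1 gets 1
```

The head with concentrated attention gives up budget to the head with dispersed attention, while the layer total is unchanged.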

## KIVI

link: [github](https://github.com/jy-yuan/KIVI)

## KV-Compress

_variable compression rates per attention head_

source: [github](https://github.com/IsaacRe/vllm-kvcompress)

> \[!notes\] Notes
>
> A variation of [[thoughts/KV compression#Ada-KV|Ada-SnapKV]]

Motivation:

- _group-query-compression_: compress the KV-cache of [[thoughts/Attention#Group-Query Attention|GQA]] without repeating it along the query-head dimension.
- A modified `PagedAttention` that computes attention _against_ a KV-cache containing a variable number of KVs per head

![[thoughts/images/vllm/kv-compress-vllm.webp]]

> For vLLM, each cache block stores KV for every attention head of every layer
>
> For KV-Compress, each block only holds KVs for a single head.
> Block tables are expanded by $l \times H$ so that a unique block for each specific KV head and layer can be retrieved
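
A toy sketch of what the expanded block table looks like for one sequence, assuming one table entry per (layer, KV head) and a naive bump allocator. `build_block_tables` and `BLOCK_SIZE` are hypothetical names, not the fork's actual data structures.

```python
from math import ceil

BLOCK_SIZE = 2  # KVs per block, kept tiny for illustration

def build_block_tables(kv_counts, first_free_block=0):
    """kv_counts[layer][head] = number of KVs retained for that head.

    Returns {(layer, head): [physical block ids]}, i.e. l x H tables
    for a single sequence instead of one shared table.
    """
    tables, next_block = {}, first_free_block
    for layer, heads in enumerate(kv_counts):
        for head, num_kvs in enumerate(heads):
            num_blocks = ceil(num_kvs / BLOCK_SIZE)
            tables[(layer, head)] = list(range(next_block, next_block + num_blocks))
            next_block += num_blocks
    return tables

# Two layers, two KV heads, each head retaining a different number of KVs.
tables = build_block_tables([[4, 1], [3, 0]])
```

Because each block belongs to a single head, a heavily compressed head simply owns fewer blocks, and no other head pays for its padding.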

### Query-Group Compression (QGC)

Existing KV compression algorithms were not designed with GQA in mind.

- [[#Pyramid-KV|Pyramid-KV]] caches and compresses KV _after_ repetition for alignment with query tensors
- This leaves redundancy in the cache before compression

> modification of eviction-based methods per group
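
A sketch of scoring KVs per group rather than per repeated head, in NumPy. It assumes query heads are laid out so that consecutive heads share a KV head (the layout `repeat_kv` produces) and uses a plain sum as the aggregate, which is my assumption. The actual metric in KV-Compress may differ.

```python
import numpy as np

def group_scores(attn, num_kv_heads):
    """attn: (num_q_heads, q_len, kv_len) attention weights.

    Returns (num_kv_heads, kv_len): one score per un-repeated KV, aggregated
    over every query head in its group and every observed query.
    """
    num_q_heads, q_len, kv_len = attn.shape
    group = num_q_heads // num_kv_heads
    # Consecutive query heads share one KV head, so fold them into a group axis.
    grouped = attn.reshape(num_kv_heads, group, q_len, kv_len)
    return grouped.sum(axis=(1, 2))

attn = np.ones((4, 2, 3))                # 4 query heads, 2 queries, 3 KVs
scores = group_scores(attn, num_kv_heads=2)
```

Eviction can then run on a `(num_kv_heads, kv_len)` score tensor, so the cache never has to be expanded to the number of query heads.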

### Block layout and allocation

idea: adapt PagedAttention to page out cache on a _per-head and per-layer, as well as per-sequence, basis_

![[thoughts/images/vllm/paged-attention-block-kv-compress.webp]]

> \[!note\]- explanation
>
> A simplified example with two KV heads and a block size of two:
>
> - KV metrics are visualized for a given cache state, highlighting blocks of a particular sequence in the decoding batch that is scheduled to evict two blocks.
> - Logical indices are displayed under the corresponding metrics slot.

#### Evict from Paged KV cache

> need to evict whole KV blocks instead of single KVs
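
A simplified sketch for a single head, assuming eviction removes the lowest-metric KVs in multiples of the block size and the survivors are then compacted. `evict_blocks` is a hypothetical helper. The real scheduler works across heads and sequences.

```python
from math import ceil

def evict_blocks(metrics, block_size, num_blocks_to_free):
    """metrics: per-KV scores for one head, in logical order.

    Evicts block_size * num_blocks_to_free KVs with the lowest metric and
    returns the logical indices that remain.
    """
    num_evict = min(len(metrics), block_size * num_blocks_to_free)
    # Lowest metric first; ties are broken by position.
    order = sorted(range(len(metrics)), key=lambda i: (metrics[i], i))
    evicted = set(order[:num_evict])
    return [i for i in range(len(metrics)) if i not in evicted]

metrics = [0.9, 0.1, 0.5, 0.2, 0.8, 0.3]
kept = evict_blocks(metrics, block_size=2, num_blocks_to_free=1)  # drops KVs 1 and 3
blocks_before = ceil(len(metrics) / 2)
blocks_after = ceil(len(kept) / 2)
```

Evicting a multiple of the block size means that, after compaction, exactly that many blocks are released, even when the last block was only partially filled.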

## automatic prefix caching

_excerpt from [github](https://github.com/vllm-project/vllm/blob/main/docs/source/automatic_prefix_caching/details.md)_

## block manager and evictor

see also: [v2](https://github.com/vllm-project/vllm/blob/main/vllm/core/block_manager.py) and [v1](https://github.com/vllm-project/vllm/blob/5eda21e773447d81ffc661ac094716420dc7b7cb/vllm/core/block_manager_v1.py), [benchmark](https://docs.google.com/document/d/1XxYUFai07ta5rE7OdtCVhLJ5J0oAxEqrGgarFdjv0Zc/edit?tab=t.0)

Reasoning for v2:

- support sliding-window attention
- lookahead slot for [[thoughts/Speculative decoding]]

