turboquant_vllm.kv_cache

TurboQuant-compressed KV cache for HuggingFace transformers.

Two integration modes:

  1. TurboQuantKVCache — Accuracy benchmark only (no VRAM savings). Compresses then immediately decompresses, storing lossy FP32 back into the standard DynamicCache. Measures quantization quality impact.

  2. CompressedDynamicCache — Real VRAM savings. Stores uint8 indices + fp32 norms in compressed form. Dequantizes lazily on each cache read (one layer at a time). Achieves ~1.9x (3-bit) or ~3.8x (4-bit, nibble-packed) compression vs an FP16 KV cache.

Both use non-invasive method replacement: we save a reference to the original update() method and replace it with a wrapper. This avoids subclassing DynamicCache, which is fragile across transformers versions.
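The method-replacement pattern can be sketched with a toy stand-in cache (the class and method names below are illustrative, not the real HuggingFace API):

```python
# Minimal sketch of the non-invasive method-replacement pattern,
# using a toy stand-in for DynamicCache (illustrative only).
class ToyCache:
    def update(self, key, value):
        return key, value


class Wrapper:
    def __init__(self, cache):
        self.cache = cache
        # Save a reference to the original bound method...
        self._original_update = cache.update
        # ...then shadow it on the instance with the wrapper.
        cache.update = self._wrapped_update

    def _wrapped_update(self, key, value):
        # A real wrapper would compress key/value here.
        return self._original_update(key, value)

    def restore(self):
        # Remove the interception entirely.
        self.cache.update = self._original_update


cache = ToyCache()
wrapper = Wrapper(cache)
assert cache.update(1, 2) == (1, 2)   # intercepted path still works
wrapper.restore()
assert cache.update(3, 4) == (3, 4)   # original method restored
```

Because the swap happens on the instance (not the class), no subclass of DynamicCache is ever created, and restore() returns the cache to its pristine state.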

Usage
# Mode 1: Accuracy benchmark (no VRAM savings)
cache = DynamicCache()
tq_cache = TurboQuantKVCache(cache, head_dim=128, bits=3)

# Mode 2: Real VRAM savings
cache = DynamicCache()
compressed = CompressedDynamicCache(cache, head_dim=128, bits=3)
# In both cases, pass cache (not the wrapper) to model.generate()

Examples:

from transformers import DynamicCache

cache = DynamicCache()
tq = TurboQuantKVCache(cache, head_dim=128, bits=3)
See Also

turboquant_vllm.compressors: TurboQuantCompressorMSE and CompressedValues. arXiv 2504.19874, Section 5.2: TurboQuant algorithm reference.

Classes

TurboQuantKVCache

TurboQuantKVCache(
    cache: Any,
    head_dim: int,
    bits: int = 3,
    *,
    seed: int = 42,
    compress_keys: bool = True,
    compress_values: bool = True,
)

Transparent KV cache compression wrapper (drop-in mode).

Intercepts cache updates to compress key/value tensors before they are stored. Both keys and values use TurboQuantCompressorMSE (full MSE-optimal quantization at the configured bit-width).

This is the "drop-in" approach where standard attention (Q @ K^T) operates on decompressed keys. For the QJL-corrected inner product path (TurboQuantProd), a custom attention kernel would be needed — see TurboQuantCompressorV2.asymmetric_attention_scores().
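To make the drop-in point concrete, here is a toy scaled dot-product score computation over decompressed keys, using the [batch, heads, seq, head_dim] layout described below (ordinary attention math, not a TurboQuant kernel):

```python
import torch

# Toy illustration of the drop-in path: scores are plain Q @ K^T over
# dequantized keys, so per-vector QJL corrections are never consulted.
def attn_scores(q: torch.Tensor, k_dequant: torch.Tensor) -> torch.Tensor:
    head_dim = q.shape[-1]
    return (q @ k_dequant.transpose(-2, -1)) / head_dim ** 0.5

q = torch.randn(1, 8, 1, 128)        # one query token
k = torch.randn(1, 8, 16, 128)       # stands in for decompressed keys
scores = attn_scores(q, k)
assert scores.shape == (1, 8, 1, 16)  # [batch, heads, q_len, kv_len]
```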

Attributes:

cache (Any): The wrapped DynamicCache instance.
key_compressor (TurboQuantCompressorMSE): Compressor for key tensors.
value_compressor (TurboQuantCompressorMSE): Compressor for value tensors.
bits (int): Quantization bits per coordinate.
head_dim (int): Model head dimension.
enabled (bool): Whether compression is active.

Examples:

from transformers import DynamicCache

cache = DynamicCache()
tq = TurboQuantKVCache(cache, head_dim=128, bits=3)
tq.enabled  # True

Initialize the TurboQuant KV cache wrapper.

Parameters:

cache (Any): A HuggingFace DynamicCache instance to wrap. Required.
head_dim (int): Dimension of each attention head. Required.
bits (int): Quantization bits per coordinate. Default: 3.
seed (int): Random seed for reproducibility. Default: 42.
compress_keys (bool): Whether to compress key tensors. Default: True.
compress_values (bool): Whether to compress value tensors. Default: True.
Source code in src/turboquant_vllm/kv_cache.py
def __init__(
    self,
    cache: Any,
    head_dim: int,
    bits: int = 3,
    *,
    seed: int = 42,
    compress_keys: bool = True,
    compress_values: bool = True,
) -> None:
    """Initialize the TurboQuant KV cache wrapper.

    Args:
        cache: A HuggingFace DynamicCache instance to wrap.
        head_dim: Dimension of each attention head.
        bits: Quantization bits per coordinate (default 3).
        seed: Random seed for reproducibility.
        compress_keys: Whether to compress key tensors.
        compress_values: Whether to compress value tensors.
    """
    self.cache = cache
    self.head_dim = head_dim
    self.bits = bits
    self.compress_keys = compress_keys
    self.compress_values = compress_values
    self.enabled = True

    # Drop-in mode: use MSE-only for BOTH keys and values.
    # TurboQuantCompressorV2 (TurboQuantProd) allocates 1 bit to QJL correction,
    # but QJL only helps when attention calls estimate_inner_product() directly.
    # Standard attention does Q @ K^T on decompressed keys, so QJL is invisible
    # and we lose 1 bit of MSE resolution for nothing. Full 3-bit MSE gives
    # ~95% cosine sim vs ~87% with 2-bit MSE + 1-bit QJL.
    # See: https://dejan.ai/blog/turboquant/ (TurboQuant_mse for drop-in cache)
    self.key_compressor = TurboQuantCompressorMSE(head_dim, bits, seed=seed)
    self.value_compressor = TurboQuantCompressorMSE(head_dim, bits, seed=seed)

    # Patch the cache's update method
    self._original_update = cache.update
    cache.update = self._compressed_update
Functions
disable
disable() -> None

Disable compression, passing through to original update.

Useful for A/B benchmarking within the same run.

Source code in src/turboquant_vllm/kv_cache.py
def disable(self) -> None:
    """Disable compression, passing through to original update.

    Useful for A/B benchmarking within the same run.
    """
    self.enabled = False
enable
enable() -> None

Re-enable compression after disable().

Source code in src/turboquant_vllm/kv_cache.py
def enable(self) -> None:
    """Re-enable compression after disable()."""
    self.enabled = True
restore
restore() -> None

Restore the original update method on the wrapped cache.

Call this to fully unwrap the cache and remove all TurboQuant interception.

Source code in src/turboquant_vllm/kv_cache.py
def restore(self) -> None:
    """Restore the original update method on the wrapped cache.

    Call this to fully unwrap the cache and remove all TurboQuant
    interception.
    """
    self.cache.update = self._original_update

CompressedDynamicCache

CompressedDynamicCache(cache: Any, head_dim: int, bits: int = 3, *, seed: int = 42)

KV cache with real VRAM savings via compressed index storage.

Stores TurboQuant-compressed representations and dequantizes lazily on each cache read. Only one layer's decompressed tensors are held in memory at a time — previous layers are freed on the next update.
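A toy sketch of that one-layer-at-a-time policy (illustrative, not the class internals):

```python
# Toy model of the lazy decompression buffers: reading layer i frees
# every other layer's materialized tensors first, so at most one
# layer's decompressed K/V lives in memory at a time.
num_layers = 4
decompressed: list = [None] * num_layers

def read_layer(i: int, dequantize) -> object:
    for j in range(num_layers):
        if j != i:
            decompressed[j] = None      # free previously materialized layers
    decompressed[i] = dequantize(i)     # materialize only layer i
    return decompressed[i]

read_layer(0, lambda i: f"tensors[{i}]")
read_layer(1, lambda i: f"tensors[{i}]")
assert decompressed == [None, "tensors[1]", None, None]
```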

Storage per token per head (head_dim=128):

=============  ======  =====  ===========  ===========
Mode           Dtype   Bytes  Compression  Quality
=============  ======  =====  ===========  ===========
FP16 baseline  fp16    256    1.0x         —
TQ3 (3-bit)    uint8   132    1.94x        ~95% cosine
TQ4 (4-bit)    nibble  68     3.76x        ~97% cosine
=============  ======  =====  ===========  ===========

At bits=4, indices are nibble-packed (two 4-bit values per byte), nearly doubling compression over TQ3 with better quality. Float32 norms are required — fp16 causes output degradation at 10K+ token sequences due to accumulated precision loss.
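The table's byte counts and the nibble-packing step can be reproduced with a short sketch (the pack/unpack helpers here are hypothetical, not the module's API):

```python
import torch

# Per-(token, head) byte accounting behind the table above: head_dim=128,
# one fp32 norm stored per vector.
head_dim, norm_bytes = 128, 4
fp16_baseline = head_dim * 2             # 2 B/coord            -> 256 B
tq3 = head_dim + norm_bytes              # 1 uint8 index/coord  -> 132 B
tq4 = head_dim // 2 + norm_bytes         # 2 indices per byte   ->  68 B
assert (fp16_baseline, tq3, tq4) == (256, 132, 68)
assert round(fp16_baseline / tq3, 2) == 1.94
assert round(fp16_baseline / tq4, 2) == 3.76

# Hypothetical nibble pack/unpack: two 4-bit indices per uint8 byte.
# This pairing is why bits=4 requires an even head_dim.
def pack_nibbles(idx: torch.Tensor) -> torch.Tensor:
    return (idx[..., 0::2] | (idx[..., 1::2] << 4)).to(torch.uint8)

def unpack_nibbles(packed: torch.Tensor) -> torch.Tensor:
    lo, hi = packed & 0x0F, (packed >> 4) & 0x0F
    return torch.stack((lo, hi), dim=-1).flatten(-2)

idx = torch.randint(0, 16, (1, 8, 16, 128), dtype=torch.uint8)
packed = pack_nibbles(idx)
assert packed.shape[-1] == head_dim // 2          # 64 packed bytes
assert torch.equal(unpack_nibbles(packed), idx)   # lossless round trip
```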

Integration strategy: non-invasive method replacement (same pattern as TurboQuantKVCache). Patches update() and get_seq_length() on the wrapped DynamicCache.

Attributes:

cache (Any): The wrapped DynamicCache instance.
key_compressor (TurboQuantCompressorMSE): Compressor for key tensors.
value_compressor (TurboQuantCompressorMSE): Compressor for value tensors.
bits (int): Quantization bits per coordinate.
head_dim (int): Model head dimension.
enabled (bool): Whether compression is active.
fused_mode (bool): When True, skip decompression in update(); the fused kernel reads compressed data via get_compressed().
rotation (Tensor): Shared rotation matrix [head_dim, head_dim].
centroids (Tensor): Shared codebook [2^bits].

Examples:

from transformers import DynamicCache

cache = DynamicCache()
compressed = CompressedDynamicCache(cache, head_dim=128, bits=4)
compressed.vram_bytes()  # 0

Initialize the compressed KV cache wrapper.

Sets up compressors, internal storage for compressed representations, and incremental decompressed buffers. fused_mode starts disabled.

Parameters:

cache (Any): A HuggingFace DynamicCache instance to wrap. Required.
head_dim (int): Dimension of each attention head. Must be even when bits=4 (required for nibble packing pairs). Required.
bits (int): Quantization bits per coordinate. Use 4 for nibble-packed storage (3.76x compression). Default: 3.
seed (int): Random seed for reproducibility. Default: 42.

Raises:

ValueError: If bits=4 and head_dim is odd.

Source code in src/turboquant_vllm/kv_cache.py
def __init__(
    self,
    cache: Any,
    head_dim: int,
    bits: int = 3,
    *,
    seed: int = 42,
) -> None:
    """Initialize the compressed KV cache wrapper.

    Sets up compressors, internal storage for compressed representations,
    and incremental decompressed buffers. ``fused_mode`` starts disabled.

    Args:
        cache: A HuggingFace DynamicCache instance to wrap.
        head_dim: Dimension of each attention head. Must be even
            when ``bits=4`` (required for nibble packing pairs).
        bits: Quantization bits per coordinate (default 3). Use 4
            for nibble-packed storage (3.76x compression).
        seed: Random seed for reproducibility.

    Raises:
        ValueError: If ``bits=4`` and ``head_dim`` is odd.
    """
    if bits == 4 and head_dim % 2 != 0:
        msg = f"bits=4 requires even head_dim for nibble packing, got {head_dim}"
        raise ValueError(msg)

    self.cache = cache
    self.head_dim = head_dim
    self.bits = bits
    self._nibble_packed = bits == 4
    self.enabled = True

    self.key_compressor = TurboQuantCompressorMSE(head_dim, bits, seed=seed)
    self.value_compressor = TurboQuantCompressorMSE(head_dim, bits, seed=seed)

    self._compressed_keys: list[_CompressedLayer] = []
    self._compressed_values: list[_CompressedLayer] = []
    self._decompressed_k: list[torch.Tensor | None] = []
    self._decompressed_v: list[torch.Tensor | None] = []
    self._original_dtype: torch.dtype = torch.bfloat16
    self.fused_mode = False

    # Patch cache methods
    self._original_update = cache.update
    self._original_get_seq_length = cache.get_seq_length
    cache.update = self._compressed_update
    cache.get_seq_length = self._compressed_get_seq_length
Attributes
rotation property
rotation: Tensor

Shared orthogonal rotation matrix [head_dim, head_dim] fp32.

K and V use the same rotation (same seed).

Returns:

Tensor: The rotation matrix from the key compressor's quantizer.

centroids property
centroids: Tensor

Shared Lloyd-Max codebook [2^bits] fp32.

Returns:

Tensor: Centroid values from the key compressor's quantizer.

Functions
get_compressed
get_compressed(layer_idx: int) -> tuple[Tensor, Tensor, Tensor, Tensor]

Return compressed K and V for a layer (fused kernel API).

Provides the raw nibble-packed indices and norms without dequantization, for use by the fused TQ4 Flash Attention kernel.

Parameters:

layer_idx (int): Transformer layer index. Required.

Returns:

tuple[Tensor, Tensor, Tensor, Tensor]: (k_packed, k_norms, v_packed, v_norms), where packed tensors are uint8 [batch, heads, seq, head_dim//2] and norms are fp32 [batch, heads, seq, 1].

Source code in src/turboquant_vllm/kv_cache.py
def get_compressed(
    self, layer_idx: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return compressed K and V for a layer (fused kernel API).

    Provides the raw nibble-packed indices and norms without
    dequantization, for use by the fused TQ4 Flash Attention kernel.

    Args:
        layer_idx: Transformer layer index.

    Returns:
        ``(k_packed, k_norms, v_packed, v_norms)`` where packed tensors
        are uint8 ``[batch, heads, seq, head_dim//2]`` and norms are
        fp32 ``[batch, heads, seq, 1]``.
    """
    k = self._compressed_keys[layer_idx]
    v = self._compressed_values[layer_idx]
    return k.indices, k.norms, v.indices, v.norms
disable
disable() -> None

Disable compression, passing through to original update.

Source code in src/turboquant_vllm/kv_cache.py
def disable(self) -> None:
    """Disable compression, passing through to original update."""
    self.enabled = False
enable
enable() -> None

Re-enable compression after disable().

Source code in src/turboquant_vllm/kv_cache.py
def enable(self) -> None:
    """Re-enable compression after disable()."""
    self.enabled = True
restore
restore() -> None

Restore original methods on the wrapped cache.

Call this to fully unwrap the cache and remove all TurboQuant interception.

Source code in src/turboquant_vllm/kv_cache.py
def restore(self) -> None:
    """Restore original methods on the wrapped cache.

    Call this to fully unwrap the cache and remove all TurboQuant
    interception.
    """
    self.cache.update = self._original_update
    self.cache.get_seq_length = self._original_get_seq_length
vram_bytes
vram_bytes() -> int

Calculate total VRAM used by compressed storage.

Returns:

int: Total bytes across all compressed layers (keys + values).

Source code in src/turboquant_vllm/kv_cache.py
def vram_bytes(self) -> int:
    """Calculate total VRAM used by compressed storage.

    Returns:
        Total bytes across all compressed layers (keys + values).
    """
    total = 0
    for layer in [*self._compressed_keys, *self._compressed_values]:
        total += layer.indices.nelement() * layer.indices.element_size()
        total += layer.norms.nelement() * layer.norms.element_size()
    return total
baseline_vram_bytes
baseline_vram_bytes() -> int

Estimate FP16 VRAM that would be used without compression.

Accounts for nibble-packed indices by doubling the last dimension to recover the original head_dim.

Returns:

int: Total bytes if keys and values were stored as FP16 tensors.

Source code in src/turboquant_vllm/kv_cache.py
def baseline_vram_bytes(self) -> int:
    """Estimate FP16 VRAM that would be used without compression.

    Accounts for nibble-packed indices by doubling the last
    dimension to recover the original head_dim.

    Returns:
        Total bytes if keys and values were stored as FP16 tensors.
    """
    total = 0
    for layer in [*self._compressed_keys, *self._compressed_values]:
        b, h, s, d = layer.indices.shape
        # Nibble-packed indices have d = head_dim // 2
        if layer.packed:
            d = d * 2
        total += b * h * s * d * 2  # FP16 = 2 bytes per element
    return total
compression_stats
compression_stats() -> dict[str, Any]

Return compression statistics for reporting.

Reports the true head_dim (not the packed index dimension) and includes a nibble_packed flag.

Returns:

dict[str, Any]: Dict with layer count, sequence length, compressed/baseline sizes in MiB, compression ratio, packing mode, and VRAM savings.

Source code in src/turboquant_vllm/kv_cache.py
def compression_stats(self) -> dict[str, Any]:
    """Return compression statistics for reporting.

    Reports the true ``head_dim`` (not the packed index dimension)
    and includes a ``nibble_packed`` flag.

    Returns:
        Dict with layer count, sequence length, compressed/baseline
        sizes in MiB, compression ratio, packing mode, and VRAM savings.
    """
    if not self._compressed_keys:
        return {}

    compressed_bytes = self.vram_bytes()
    baseline_bytes = self.baseline_vram_bytes()
    ratio = baseline_bytes / compressed_bytes if compressed_bytes > 0 else 0.0

    layer = self._compressed_keys[0]
    b, h, s, _ = layer.indices.shape

    return {
        "num_layers": len(self._compressed_keys),
        "seq_len": s,
        "batch_size": b,
        "num_heads": h,
        "head_dim": self.head_dim,
        "bits": self.bits,
        "nibble_packed": self._nibble_packed,
        "compressed_mib": compressed_bytes / (1024 * 1024),
        "baseline_mib": baseline_bytes / (1024 * 1024),
        "compression_ratio": round(ratio, 2),
        "savings_mib": (baseline_bytes - compressed_bytes) / (1024 * 1024),
    }