turboquant_vllm.kv_cache

TurboQuant-compressed KV cache for HuggingFace transformers.

Two integration modes:

  1. TurboQuantKVCache — Accuracy benchmark only (no VRAM savings). Compresses then immediately decompresses, storing lossy FP32 back into the standard DynamicCache. Measures quantization quality impact.

  2. CompressedDynamicCache — Real VRAM savings. Stores uint8 indices + fp32 norms in compressed form. Dequantizes lazily on each cache read (one layer at a time). Supports asymmetric K/V bit-widths via the k_bits and v_bits parameters.

Both use non-invasive method replacement: we save a reference to the original update() method and replace it with a wrapper. This avoids subclassing DynamicCache, which is fragile across transformers versions. Both classes support the context manager protocol (with statement) for automatic restore() on scope exit, and detect double-wrapping.
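The save-and-replace pattern can be sketched in isolation. This is a minimal illustration with hypothetical names (`PatchingWrapper`, `_wrapped_update`), not the library's classes:

```python
class PatchingWrapper:
    """Minimal sketch of non-invasive method replacement on a cache object."""

    def __init__(self, cache):
        self.cache = cache
        self._original_update = cache.update   # save a reference first
        cache.update = self._wrapped_update    # then patch the instance attribute

    def _wrapped_update(self, *args, **kwargs):
        # ... compress / transform here, then delegate ...
        return self._original_update(*args, **kwargs)

    def restore(self):
        # Unwrap: put the original bound method back.
        self.cache.update = self._original_update

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.restore()
        return False  # never suppress exceptions
```

Because only an instance attribute is patched, no subclass of the cache is needed, and `restore()` leaves the object exactly as it was.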

Usage
# Mode 1: Accuracy benchmark (no VRAM savings)
cache = DynamicCache()
tq_cache = TurboQuantKVCache(cache, head_dim=128, bits=3)

# Mode 2: Real VRAM savings (with context manager)
cache = DynamicCache()
with CompressedDynamicCache(cache, head_dim=128, bits=3) as compressed:
    pass  # cache.update is patched inside the block
# cache.update is restored here

Examples:

from transformers import DynamicCache

cache = DynamicCache()
tq = TurboQuantKVCache(cache, head_dim=128, bits=3)

See Also

turboquant_vllm.compressors: TurboQuantCompressorMSE and CompressedValues. arXiv 2504.19874, Section 5.2: TurboQuant algorithm reference.

Classes

TurboQuantKVCache

TurboQuantKVCache(
    cache: Any,
    head_dim: int,
    bits: int = 3,
    *,
    seed: int = 42,
    compress_keys: bool = True,
    compress_values: bool = True,
)

Transparent KV cache compression wrapper (drop-in mode).

Intercepts cache updates to compress key/value tensors before they are stored. Both keys and values use TurboQuantCompressorMSE (full MSE-optimal quantization at the configured bit-width).

This is the "drop-in" approach where standard attention (Q @ K^T) operates on decompressed keys. For the QJL-corrected inner product path (TurboQuantProd), a custom attention kernel would be needed — see TurboQuantCompressorV2.asymmetric_attention_scores().

Supports the context manager protocol for automatic restore() on scope exit, and warns if the cache is already wrapped.

Attributes:

Name Type Description
cache Any

The wrapped DynamicCache instance.

key_compressor TurboQuantCompressorMSE

Compressor for key tensors.

value_compressor TurboQuantCompressorMSE

Compressor for value tensors.

bits int

Quantization bits per coordinate.

head_dim int

Model head dimension.

enabled bool

Whether compression is active.

Examples:

from transformers import DynamicCache

cache = DynamicCache()
tq = TurboQuantKVCache(cache, head_dim=128, bits=3)
tq.enabled  # True

Initialize the TurboQuant KV cache wrapper.

Parameters:

Name Type Description Default
cache Any

A HuggingFace DynamicCache instance to wrap.

required
head_dim int

Dimension of each attention head.

required
bits int

Quantization bits per coordinate (default 3).

3
seed int

Random seed for reproducibility.

42
compress_keys bool

Whether to compress key tensors.

True
compress_values bool

Whether to compress value tensors.

True

Warns:

Type Description
UserWarning

If cache is already wrapped by a TurboQuant wrapper. Call restore() on the existing wrapper first.

Source code in src/turboquant_vllm/kv_cache.py
def __init__(
    self,
    cache: Any,
    head_dim: int,
    bits: int = 3,
    *,
    seed: int = 42,
    compress_keys: bool = True,
    compress_values: bool = True,
) -> None:
    """Initialize the TurboQuant KV cache wrapper.

    Args:
        cache: A HuggingFace DynamicCache instance to wrap.
        head_dim: Dimension of each attention head.
        bits: Quantization bits per coordinate (default 3).
        seed: Random seed for reproducibility.
        compress_keys: Whether to compress key tensors.
        compress_values: Whether to compress value tensors.

    Warns:
        UserWarning: If ``cache`` is already wrapped by a TurboQuant
            wrapper. Call ``restore()`` on the existing wrapper first.
    """
    self.cache = cache
    self.head_dim = head_dim
    self.bits = bits
    self.compress_keys = compress_keys
    self.compress_values = compress_values
    self.enabled = True

    # Drop-in mode: use MSE-only for BOTH keys and values.
    # TurboQuantCompressorV2 (TurboQuantProd) allocates 1 bit to QJL correction,
    # but QJL only helps when attention calls estimate_inner_product() directly.
    # Standard attention does Q @ K^T on decompressed keys, so QJL is invisible
    # and we lose 1 bit of MSE resolution for nothing. Full 3-bit MSE gives
    # ~95% cosine sim vs ~87% with 2-bit MSE + 1-bit QJL.
    # See: https://dejan.ai/blog/turboquant/ (TurboQuant_mse for drop-in cache)
    self.key_compressor = TurboQuantCompressorMSE(head_dim, bits, seed=seed)
    self.value_compressor = TurboQuantCompressorMSE(head_dim, bits, seed=seed)

    # Detect double-compression: if cache.update is already a bound method
    # on one of our wrapper classes, the cache is already wrapped.
    if hasattr(cache.update, "__self__") and isinstance(
        cache.update.__self__, (CompressedDynamicCache, TurboQuantKVCache)
    ):
        warnings.warn(
            "Cache is already wrapped by TurboQuant. "
            "Call restore() on the existing wrapper first.",
            UserWarning,
            stacklevel=2,
        )

    # Patch the cache's update method
    self._original_update = cache.update
    cache.update = self._compressed_update
Functions
disable
disable() -> None

Disable compression, passing through to original update.

Useful for A/B benchmarking within the same run.

Source code in src/turboquant_vllm/kv_cache.py
def disable(self) -> None:
    """Disable compression, passing through to original update.

    Useful for A/B benchmarking within the same run.
    """
    self.enabled = False
enable
enable() -> None

Re-enable compression after disable().

Source code in src/turboquant_vllm/kv_cache.py
def enable(self) -> None:
    """Re-enable compression after disable()."""
    self.enabled = True
restore
restore() -> None

Restore the original update method on the wrapped cache.

Call this to fully unwrap the cache and remove all TurboQuant interception.

Source code in src/turboquant_vllm/kv_cache.py
def restore(self) -> None:
    """Restore the original update method on the wrapped cache.

    Call this to fully unwrap the cache and remove all TurboQuant
    interception.
    """
    self.cache.update = self._original_update
__enter__
__enter__() -> TurboQuantKVCache

Enter the context manager.

Returns:

Type Description
TurboQuantKVCache

Self, for use in with ... as bindings.

Source code in src/turboquant_vllm/kv_cache.py
def __enter__(self) -> TurboQuantKVCache:
    """Enter the context manager.

    Returns:
        Self, for use in ``with ... as`` bindings.
    """
    return self
__exit__
__exit__(*exc: object) -> bool

Exit the context manager, restoring the original cache methods.

Returns:

Type Description
bool

False — exceptions are never suppressed.

Source code in src/turboquant_vllm/kv_cache.py
def __exit__(self, *exc: object) -> bool:
    """Exit the context manager, restoring the original cache methods.

    Returns:
        False — exceptions are never suppressed.
    """
    self.restore()
    return False

CompressedDynamicCache

CompressedDynamicCache(
    cache: Any,
    head_dim: int,
    bits: int | None = 3,
    *,
    k_bits: int | None = None,
    v_bits: int | None = None,
    seed: int = 42,
    model_config: Any = None,
)

KV cache with real VRAM savings via compressed index storage.

Stores TurboQuant-compressed representations and dequantizes lazily on each cache read. Only one layer's decompressed tensors are held in memory at a time — previous layers are freed on the next update.

Supports heterogeneous head dimensions for the lazy dequantized (non-fused) cache-read path via per-head_dim compressors created lazily on first use. The fused path consumes shared rotation and centroids for the primary head_dim only, so it must not be used for models with mixed head dimensions (e.g. Gemma 4: d=256 sliding, d=512 global).
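The lazy per-head_dim lookup described above might look roughly like the following. The real `_get_compressors` is not shown on this page; the class name, attribute names, and factory signature here are illustrative:

```python
class LazyCompressorPool:
    """Hypothetical sketch: create one compressor per head_dim on first use."""

    def __init__(self, bits: int, seed: int, factory):
        self._bits = bits
        self._seed = seed
        self._factory = factory  # e.g. TurboQuantCompressorMSE
        self._by_dim: dict[int, object] = {}

    def get(self, head_dim: int):
        # Created lazily on the first request for each head_dim, then reused,
        # so mixed-dimension models (e.g. d=256 sliding, d=512 global) each
        # get a correctly sized compressor without up-front knowledge.
        if head_dim not in self._by_dim:
            self._by_dim[head_dim] = self._factory(
                head_dim, self._bits, seed=self._seed
            )
        return self._by_dim[head_dim]
```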

Storage per token per head (head_dim=128):

=============  ======  =====  ===========  ===========
Mode           Dtype   Bytes  Compression  Quality
=============  ======  =====  ===========  ===========
FP16 baseline  fp16    256    1.0x         —
TQ3 (3-bit)    uint8   132    1.94x        ~95% cosine
TQ4 (4-bit)    nibble  68     3.76x        ~97% cosine
=============  ======  =====  ===========  ===========

At bits=4, indices are nibble-packed (two 4-bit values per byte), nearly doubling compression over TQ3 with better quality. Float32 norms are required — fp16 causes output degradation at 10K+ token sequences due to accumulated precision loss.
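Nibble packing can be sketched in plain Python. This illustrates the scheme described above (two 4-bit codes per byte); the library's actual packing layout may differ:

```python
def pack_nibbles(codes: list[int]) -> bytes:
    """Pack 4-bit codes (values 0..15) pairwise: two codes per output byte."""
    assert len(codes) % 2 == 0, "nibble packing needs an even head_dim"
    return bytes((codes[i + 1] << 4) | codes[i] for i in range(0, len(codes), 2))

def unpack_nibbles(packed: bytes) -> list[int]:
    """Recover the original code sequence from packed bytes."""
    out = []
    for b in packed:
        out.append(b & 0x0F)         # low nibble: even position
        out.append((b >> 4) & 0x0F)  # high nibble: odd position
    return out

# Storage arithmetic from the table above (head_dim=128):
#   TQ3: 128 one-byte indices + 4-byte fp32 norm = 132 B  -> 256/132 ≈ 1.94x
#   TQ4: 128 // 2 = 64 packed bytes + 4-byte norm =  68 B -> 256/68  ≈ 3.76x
```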

For models with mixed global and sliding window attention layers (e.g. Gemma-2, Gemma-3), global (full) attention layers automatically bypass compression: they are the sole retrieval path for tokens beyond the window, and quantization error on them destroys long-range retrieval quality. SWA layers are compressed normally under lazy initialization, though pre-populated SWA layers may also be bypassed via the is_sliding attribute on DynamicSlidingWindowLayer. Pass model_config to enable a diagnostic warning when the cache lacks SWA metadata.

Integration strategy: non-invasive method replacement (same pattern as TurboQuantKVCache). Patches update() and get_seq_length() on the wrapped DynamicCache. Supports the context manager protocol for automatic restore() on scope exit, and warns on double-wrap. Compatible with both transformers 4.x and 5.x lazy_initialization signatures via try/except fallback in _ensure_layer_initialized.

Attributes:

Name Type Description
cache Any

The wrapped DynamicCache instance.

key_compressor TurboQuantCompressorMSE

Compressor for key tensors.

value_compressor TurboQuantCompressorMSE

Compressor for value tensors.

bits int

Quantization bits per coordinate.

head_dim int

Model head dimension.

enabled bool

Whether compression is active.

fused_mode bool

When True, skip decompression in update() (fused kernel reads compressed data via get_compressed()).

rotation Tensor

Shared rotation matrix [head_dim, head_dim].

centroids Tensor

Shared codebook [2^bits].

Examples:

from transformers import DynamicCache

cache = DynamicCache()
compressed = CompressedDynamicCache(cache, head_dim=128, bits=4)
compressed.vram_bytes()  # 0

Initialize the compressed KV cache wrapper.

Sets up per-head_dim compressors (lazily created via _get_compressors()), internal storage for compressed representations, and incremental decompressed buffers. fused_mode starts disabled. When model_config has mixed sliding/full attention layer_types, full attention layers are bypassed (with list padding) to preserve retrieval quality while allowing get_seq_length to delegate correctly.

Keys and values can use different bit-widths via k_bits and v_bits. When both are None, bits applies to both (backward compatible). Any 4-bit component requires even head_dim for nibble packing.
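The bit-width resolution rule can be restated as a standalone sketch, mirroring the docstring rather than the library code:

```python
def resolve_bits(bits, k_bits, v_bits):
    """Explicit k_bits/v_bits override the shared `bits` shorthand."""
    k = k_bits if k_bits is not None else bits
    v = v_bits if v_bits is not None else bits
    if k is None or v is None:
        # All three None: no bit-width was specified at all.
        raise ValueError(
            "No bit-width specified. Provide `bits` as shorthand, "
            "or `k_bits` and `v_bits` individually."
        )
    return k, v
```

For example, `CompressedDynamicCache(cache, head_dim=128, bits=3, v_bits=4)` would quantize keys at 3 bits and values at 4 bits.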

Parameters:

Name Type Description Default
cache Any

A HuggingFace DynamicCache instance to wrap.

required
head_dim int

Dimension of each attention head. Must be even when any component uses 4-bit (nibble packing).

required
bits int | None

Shorthand for k_bits=bits, v_bits=bits.

3
k_bits int | None

Key quantization bits (overrides bits for keys).

None
v_bits int | None

Value quantization bits (overrides bits for values).

None
seed int

Random seed for reproducibility.

42
model_config Any

Optional model config (e.g. model.config). When provided, enables detection of misconfigured caches for models with mixed global/SWA layers (e.g. Gemma).

None

Raises:

Type Description
ValueError

If no bit-width is specified (all three are None).

ValueError

If any 4-bit component has odd head_dim.

Warns:

Type Description
UserWarning

If cache is already wrapped by a TurboQuant wrapper. Call restore() on the existing wrapper first.

UserWarning

If model_config has layer_types with sliding attention entries but the cache lacks SWA layers. Pass DynamicCache(config=model.config) to fix.

Source code in src/turboquant_vllm/kv_cache.py
def __init__(
    self,
    cache: Any,
    head_dim: int,
    bits: int | None = 3,
    *,
    k_bits: int | None = None,
    v_bits: int | None = None,
    seed: int = 42,
    model_config: Any = None,
) -> None:
    """Initialize the compressed KV cache wrapper.

    Sets up per-head_dim compressors (lazily created via
    ``_get_compressors()``), internal storage for compressed
    representations, and incremental decompressed buffers.
    ``fused_mode`` starts disabled. When ``model_config`` has
    mixed sliding/full attention ``layer_types``, full attention
    layers are bypassed (with list padding) to preserve retrieval
    quality while allowing ``get_seq_length`` to delegate correctly.

    Keys and values can use different bit-widths via ``k_bits`` and
    ``v_bits``.  When both are ``None``, ``bits`` applies to both
    (backward compatible).  Any 4-bit component requires even
    ``head_dim`` for nibble packing.

    Args:
        cache: A HuggingFace DynamicCache instance to wrap.
        head_dim: Dimension of each attention head. Must be even
            when any component uses 4-bit (nibble packing).
        bits: Shorthand for ``k_bits=bits, v_bits=bits``.
        k_bits: Key quantization bits (overrides ``bits`` for keys).
        v_bits: Value quantization bits (overrides ``bits`` for values).
        seed: Random seed for reproducibility.
        model_config: Optional model config (e.g. ``model.config``).
            When provided, enables detection of misconfigured caches
            for models with mixed global/SWA layers (e.g. Gemma).

    Raises:
        ValueError: If no bit-width is specified (all three are None).
        ValueError: If any 4-bit component has odd ``head_dim``.

    Warns:
        UserWarning: If ``cache`` is already wrapped by a TurboQuant
            wrapper. Call ``restore()`` on the existing wrapper first.
        UserWarning: If ``model_config`` has ``layer_types`` with
            sliding attention entries but the cache lacks SWA layers.
            Pass ``DynamicCache(config=model.config)`` to fix.
    """
    # Resolve per-component bit-widths
    resolved_k = k_bits if k_bits is not None else bits
    resolved_v = v_bits if v_bits is not None else bits

    if resolved_k is None or resolved_v is None:
        msg = (
            "No bit-width specified. Provide `bits` as shorthand, "
            "or `k_bits` and `v_bits` individually."
        )
        raise ValueError(msg)

    if resolved_k == 4 and head_dim % 2 != 0:
        msg = f"k_bits=4 requires even head_dim for nibble packing, got {head_dim}"
        raise ValueError(msg)
    if resolved_v == 4 and head_dim % 2 != 0:
        msg = f"v_bits=4 requires even head_dim for nibble packing, got {head_dim}"
        raise ValueError(msg)

    self.cache = cache
    self.head_dim = head_dim
    self.bits = resolved_k  # backward compat: bits reflects k_bits
    self.k_bits = resolved_k
    self.v_bits = resolved_v
    self._k_nibble_packed = resolved_k == 4
    self._v_nibble_packed = resolved_v == 4
    self._seed = seed
    self.enabled = True

    # Per-head_dim compressors for heterogeneous architectures
    # (Gemma 4: d=256 sliding, d=512 global). Created lazily via
    # _get_compressors() on first use of each head_dim.
    self._key_compressors: dict[int, TurboQuantCompressorMSE] = {}
    self._value_compressors: dict[int, TurboQuantCompressorMSE] = {}
    # Pre-create compressor for the primary head_dim
    self._key_compressors[head_dim] = TurboQuantCompressorMSE(
        head_dim, resolved_k, seed=seed
    )
    self._value_compressors[head_dim] = TurboQuantCompressorMSE(
        head_dim, resolved_v, seed=seed
    )

    self._compressed_keys: list[_CompressedLayer | None] = []
    self._compressed_values: list[_CompressedLayer | None] = []
    self._decompressed_k: list[torch.Tensor | None] = []
    self._decompressed_v: list[torch.Tensor | None] = []
    self._original_dtype: torch.dtype = torch.bfloat16
    self.fused_mode = False

    # Detect double-compression: if cache.update is already a bound method
    # on one of our wrapper classes, the cache is already wrapped.
    if hasattr(cache.update, "__self__") and isinstance(
        cache.update.__self__, (CompressedDynamicCache, TurboQuantKVCache)
    ):
        warnings.warn(
            "Cache is already wrapped by TurboQuant. "
            "Call restore() on the existing wrapper first.",
            UserWarning,
            stacklevel=2,
        )

    # Sliding window quality: in models with mixed sliding/full
    # attention (Gemma 3/4), full attention layers are the sole
    # retrieval path for tokens beyond the window. TQ quantization
    # error on these critical layers destroys NIAH at 2K+ (33% vs
    # 100% baseline). Bypass compression on full attention layers
    # to preserve retrieval quality. SWA layers may also be bypassed
    # by is_sliding when cache layers are pre-populated (e.g. tests),
    # but in production (lazy init) SWA layers compress normally.
    layer_types = getattr(model_config, "layer_types", None)
    self._full_attn_bypass: set[int] = set()
    if layer_types and any("sliding" in lt for lt in layer_types):
        self._full_attn_bypass = {
            i for i, lt in enumerate(layer_types) if "full" in lt
        }

    # SWA detection: warn when config has mixed attention layers but
    # cache was created without config (no SWA layer metadata).
    # Checks layer_types (not sliding_window alone) to avoid false
    # positives on Mistral-style uniform SWA configs.
    if (
        model_config is not None
        and layer_types
        and any("sliding" in lt for lt in layer_types)
        and not any(getattr(layer, "is_sliding", False) for layer in cache.layers)
    ):
        warnings.warn(
            "Cache appears to lack sliding window layer metadata. "
            "Create cache with `DynamicCache(config=model.config)` "
            "for correct Gemma support.",
            UserWarning,
            stacklevel=2,
        )

    # Patch cache methods
    self._original_update = cache.update
    self._original_get_seq_length = cache.get_seq_length
    cache.update = self._compressed_update
    cache.get_seq_length = self._compressed_get_seq_length
Attributes
key_compressor property
key_compressor: TurboQuantCompressorMSE

Primary key compressor (backward compat).

value_compressor property
value_compressor: TurboQuantCompressorMSE

Primary value compressor (backward compat).

rotation property
rotation: Tensor

Shared orthogonal rotation matrix [head_dim, head_dim] fp32.

K and V use the same rotation (same seed).

Returns:

Type Description
Tensor

The rotation matrix from the key compressor's quantizer.

centroids property
centroids: Tensor

Shared Lloyd-Max codebook [2^bits] fp32.

Returns:

Type Description
Tensor

Centroid values from the key compressor's quantizer.

Functions
get_compressed
get_compressed(layer_idx: int) -> tuple[Tensor, Tensor, Tensor, Tensor]

Return compressed K and V for a layer (fused kernel API).

Provides the raw nibble-packed indices and norms without dequantization, for use by the fused TQ4 Flash Attention kernel.

Parameters:

Name Type Description Default
layer_idx int

Transformer layer index.

required

Returns:

Type Description
tuple[Tensor, Tensor, Tensor, Tensor]

(k_packed, k_norms, v_packed, v_norms), where packed tensors are uint8 [batch, heads, seq, head_dim//2] and norms are fp32 [batch, heads, seq, 1].

Raises:

Type Description
ValueError

If layer_idx refers to a layer with no compressed data (not yet updated, or SWA-bypassed).

Source code in src/turboquant_vllm/kv_cache.py
def get_compressed(
    self, layer_idx: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return compressed K and V for a layer (fused kernel API).

    Provides the raw nibble-packed indices and norms without
    dequantization, for use by the fused TQ4 Flash Attention kernel.

    Args:
        layer_idx: Transformer layer index.

    Returns:
        ``(k_packed, k_norms, v_packed, v_norms)`` where packed tensors
        are uint8 ``[batch, heads, seq, head_dim//2]`` and norms are
        fp32 ``[batch, heads, seq, 1]``.

    Raises:
        ValueError: If ``layer_idx`` refers to a layer with no
            compressed data (not yet updated, or SWA-bypassed).
    """
    if layer_idx >= len(self._compressed_keys):
        msg = f"Layer {layer_idx} has no compressed data (not yet updated)"
        raise ValueError(msg)
    if self._compressed_keys[layer_idx] is None:
        msg = f"Layer {layer_idx} has no compressed data (SWA-bypassed layer)"
        raise ValueError(msg)
    k = self._compressed_keys[layer_idx]
    v = self._compressed_values[layer_idx]
    assert k is not None  # guaranteed by the guard above
    assert v is not None
    return k.indices, k.norms, v.indices, v.norms
disable
disable() -> None

Disable compression, passing through to original update.

Source code in src/turboquant_vllm/kv_cache.py
def disable(self) -> None:
    """Disable compression, passing through to original update."""
    self.enabled = False
enable
enable() -> None

Re-enable compression after disable().

Source code in src/turboquant_vllm/kv_cache.py
def enable(self) -> None:
    """Re-enable compression after disable()."""
    self.enabled = True
restore
restore() -> None

Restore original methods on the wrapped cache.

Call this to fully unwrap the cache and remove all TurboQuant interception.

Source code in src/turboquant_vllm/kv_cache.py
def restore(self) -> None:
    """Restore original methods on the wrapped cache.

    Call this to fully unwrap the cache and remove all TurboQuant
    interception.
    """
    self.cache.update = self._original_update
    self.cache.get_seq_length = self._original_get_seq_length
__enter__
__enter__() -> CompressedDynamicCache

Enter the context manager.

Returns:

Type Description
CompressedDynamicCache

Self, for use in with ... as bindings.

Source code in src/turboquant_vllm/kv_cache.py
def __enter__(self) -> CompressedDynamicCache:
    """Enter the context manager.

    Returns:
        Self, for use in ``with ... as`` bindings.
    """
    return self
__exit__
__exit__(*exc: object) -> bool

Exit the context manager, restoring the original cache methods.

Returns:

Type Description
bool

False — exceptions are never suppressed.

Source code in src/turboquant_vllm/kv_cache.py
def __exit__(self, *exc: object) -> bool:
    """Exit the context manager, restoring the original cache methods.

    Returns:
        False — exceptions are never suppressed.
    """
    self.restore()
    return False
vram_bytes
vram_bytes() -> int

Calculate total VRAM used by compressed storage.

SWA-bypassed layers (None entries) are excluded from the total.

Returns:

Type Description
int

Total bytes across all compressed layers (keys + values).

Source code in src/turboquant_vllm/kv_cache.py
def vram_bytes(self) -> int:
    """Calculate total VRAM used by compressed storage.

    SWA-bypassed layers (None entries) are excluded from the total.

    Returns:
        Total bytes across all compressed layers (keys + values).
    """
    total = 0
    for layer in [*self._compressed_keys, *self._compressed_values]:
        if layer is None:
            continue
        total += layer.indices.nelement() * layer.indices.element_size()
        total += layer.norms.nelement() * layer.norms.element_size()
    return total
baseline_vram_bytes
baseline_vram_bytes() -> int

Estimate FP16 VRAM that would be used without compression.

Accounts for nibble-packed indices by doubling the last dimension to recover the original head_dim. SWA-bypassed layers (None entries) are excluded.

Returns:

Type Description
int

Total bytes if keys and values were stored as FP16 tensors.

Source code in src/turboquant_vllm/kv_cache.py
def baseline_vram_bytes(self) -> int:
    """Estimate FP16 VRAM that would be used without compression.

    Accounts for nibble-packed indices by doubling the last
    dimension to recover the original head_dim. SWA-bypassed layers
    (None entries) are excluded.

    Returns:
        Total bytes if keys and values were stored as FP16 tensors.
    """
    total = 0
    for layer in [*self._compressed_keys, *self._compressed_values]:
        if layer is None:
            continue
        b, h, s, d = layer.indices.shape
        # Nibble-packed indices have d = head_dim // 2
        if layer.packed:
            d = d * 2
        total += b * h * s * d * 2  # FP16 = 2 bytes per element
    return total
compression_stats
compression_stats() -> dict[str, Any]

Return compression statistics for reporting.

Reports per-component bit-widths, the true head_dim, compression ratio, and per-sequence VRAM estimates at representative context lengths (4K, 16K, 32K tokens). Only counts compressed (non-SWA) layers. VRAM estimates are per sequence — multiply by batch size for total memory.
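As a worked example of the per-token cost described above, assuming a hypothetical `packed_size` helper that mirrors the internal `_packed_size` (one uint8 per coordinate at 3 bits, nibble-packed two per byte at 4 bits, plus a 4-byte fp32 norm per head); the model shape below is also illustrative:

```python
def packed_size(bits: int, head_dim: int) -> int:
    # Assumed behavior of the internal _packed_size helper (illustrative):
    # 4-bit codes are nibble-packed two per byte; 3-bit codes use one
    # uint8 per coordinate.
    return head_dim // 2 if bits == 4 else head_dim

# Hypothetical model shape: 32 layers, 8 KV heads, head_dim=128, TQ4 K and V.
num_layers, kv_heads, head_dim = 32, 8, 128
k_cost = packed_size(4, head_dim) + 4  # 64 index bytes + fp32 norm = 68 B
v_cost = packed_size(4, head_dim) + 4
bytes_per_token = num_layers * kv_heads * (k_cost + v_cost)
print(bytes_per_token)                           # 34816 B = 34 KiB per token
print(bytes_per_token * 32768 // (1024 * 1024))  # 1088 MiB at a 32K context
```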

Returns:

Type Description
dict[str, Any]

Dict with layer count, sequence length, per-component bit-widths, compressed/baseline sizes in MiB, compression ratio, VRAM savings, and per-sequence VRAM estimates at representative context lengths.

Source code in src/turboquant_vllm/kv_cache.py
def compression_stats(self) -> dict[str, Any]:
    """Return compression statistics for reporting.

    Reports per-component bit-widths, the true ``head_dim``, compression
    ratio, and per-sequence VRAM estimates at representative context
    lengths (4K, 16K, 32K tokens). Only counts compressed (non-SWA)
    layers.  VRAM estimates are per sequence — multiply by batch size
    for total memory.

    Returns:
        Dict with layer count, sequence length, per-component bit-widths,
        compressed/baseline sizes in MiB, compression ratio, VRAM savings,
        and per-sequence VRAM estimates at representative context lengths.
    """
    compressed_layers = [ck for ck in self._compressed_keys if ck is not None]
    if not compressed_layers:
        return {}

    compressed_bytes = self.vram_bytes()
    baseline_bytes = self.baseline_vram_bytes()
    ratio = baseline_bytes / compressed_bytes if compressed_bytes > 0 else 0.0

    layer = compressed_layers[0]
    b, h, s, _ = layer.indices.shape
    num_layers = len(compressed_layers)

    # Per-token-per-head byte cost for each component
    k_bytes_per_th = _packed_size(self.k_bits, self.head_dim) + 4  # indices + norm
    v_bytes_per_th = _packed_size(self.v_bits, self.head_dim) + 4
    bytes_per_token = num_layers * h * (k_bytes_per_th + v_bytes_per_th)

    # VRAM estimates at representative context lengths (per sequence —
    # multiply by batch_size for total VRAM).
    vram_estimate: dict[int, float] = {}
    for ctx_len in (4096, 16384, 32768):
        vram_estimate[ctx_len] = round(bytes_per_token * ctx_len / (1024 * 1024), 2)

    return {
        "num_layers": num_layers,
        "seq_len": s,
        "batch_size": b,
        "num_heads": h,
        "head_dim": self.head_dim,
        "bits": self.bits,
        "k_bits": self.k_bits,
        "v_bits": self.v_bits,
        "k_nibble_packed": self._k_nibble_packed,
        "v_nibble_packed": self._v_nibble_packed,
        "compressed_mib": compressed_bytes / (1024 * 1024),
        "baseline_mib": baseline_bytes / (1024 * 1024),
        "compression_ratio": round(ratio, 2),
        "savings_mib": (baseline_bytes - compressed_bytes) / (1024 * 1024),
        "vram_estimate": vram_estimate,
    }