turboquant_vllm

TurboQuant KV cache compression for consumer GPUs.

Implements Google's TurboQuant algorithm (ICLR 2026) for compressing transformer key-value caches to 3-4 bits per coordinate with near-zero accuracy loss. Designed for benchmarking on consumer hardware (RTX 4090).

Reference: arXiv 2504.19874 — "TurboQuant: Online Vector Quantization with Near-optimal Distortion Rate"

Attributes:

| Name | Type | Description |
|------|------|-------------|
| CompressedDynamicCache | | KV cache with real VRAM savings (uint8 + fp32). |
| TurboQuantKVCache | | Accuracy-only KV cache wrapper (no VRAM savings). |
| TurboQuantCompressorMSE | | Value cache compressor (MSE-optimal). |
| TurboQuantCompressorV2 | | Key cache compressor (QJL-corrected). |
| TurboQuantMSE | | Stage 1 quantizer (rotation + Lloyd-Max). |
| TurboQuantProd | | Stage 1 + 2 quantizer (MSE + QJL). |
| LloydMaxCodebook | | Precomputed optimal scalar quantizer. |
| solve_lloyd_max | tuple[Tensor, Tensor] | Factory for Lloyd-Max codebooks (cached). |

Examples:

from turboquant_vllm import TurboQuantKVCache

wrapper = TurboQuantKVCache(cache, head_dim=128, bits=3)
See Also

- :mod:`turboquant_vllm.benchmark`: CLI harness for benchmarking.
- :mod:`turboquant_vllm.lloyd_max`: Lloyd-Max codebook solver.

Classes

TurboQuantCompressorMSE

TurboQuantCompressorMSE(head_dim: int, bits: int = 3, *, seed: int = 42)

Value cache compressor with MSE-optimal reconstruction.

Uses Stage 1 only (TurboQuantMSE) for value vectors. Values appear in the softmax(scores) @ V multiplication where reconstruction quality matters but inner-product structure does not.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| quantizer | TurboQuantMSE | TurboQuantMSE instance. |
| bits | int | Bits per coordinate. |
| head_dim | int | Model head dimension. |

Examples:

Compress and reconstruct value tensors:

comp = TurboQuantCompressorMSE(head_dim=128, bits=3)
compressed = comp.compress(value_states)
reconstructed = comp.decompress(compressed)

Initialize the value compressor.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| head_dim | int | Dimension of each attention head. | required |
| bits | int | Bits per coordinate. | 3 |
| seed | int | Random seed for reproducibility. | 42 |

Source code in src/turboquant_vllm/compressors.py
def __init__(self, head_dim: int, bits: int = 3, *, seed: int = 42) -> None:
    """Initialize the value compressor.

    Args:
        head_dim: Dimension of each attention head.
        bits: Bits per coordinate (default 3).
        seed: Random seed for reproducibility.
    """
    self.head_dim = head_dim
    self.bits = bits
    self.quantizer = TurboQuantMSE(head_dim, bits, seed=seed)
Functions
compress
compress(values: Tensor) -> CompressedValues

Compress value tensors.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| values | Tensor | Value tensor of shape (batch, heads, seq_len, head_dim). | required |

Returns:

| Type | Description |
|------|-------------|
| CompressedValues | CompressedValues containing indices and norms. |

Source code in src/turboquant_vllm/compressors.py
def compress(self, values: torch.Tensor) -> CompressedValues:
    """Compress value tensors.

    Args:
        values: Value tensor of shape (batch, heads, seq_len, head_dim).

    Returns:
        CompressedValues containing indices and norms.
    """
    original_dtype = values.dtype
    indices, norms = self.quantizer.quantize(values.float())
    return CompressedValues(
        indices=indices,
        norms=norms,
        original_dtype=original_dtype,
    )
decompress
decompress(compressed: CompressedValues) -> Tensor

Reconstruct value tensors from compressed representation.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| compressed | CompressedValues | CompressedValues from compress(). | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | Reconstructed value tensor in the original dtype. |

Source code in src/turboquant_vllm/compressors.py
def decompress(self, compressed: CompressedValues) -> torch.Tensor:
    """Reconstruct value tensors from compressed representation.

    Args:
        compressed: CompressedValues from compress().

    Returns:
        Reconstructed value tensor in the original dtype.
    """
    result = self.quantizer.dequantize(compressed.indices, compressed.norms)
    return result.to(compressed.original_dtype)

TurboQuantCompressorV2

TurboQuantCompressorV2(head_dim: int, bits: int = 3, *, seed: int = 42)

Key cache compressor with unbiased attention score estimation.

Uses the full two-stage TurboQuantProd algorithm to compress key vectors while preserving accurate inner product estimation for attention computation (Q·K^T).

Attributes:

| Name | Type | Description |
|------|------|-------------|
| quantizer | TurboQuantProd | Two-stage TurboQuantProd instance. |
| bits | int | Total bit budget per coordinate. |
| head_dim | int | Model head dimension. |

Examples:

Compress keys and compute attention scores directly:

comp = TurboQuantCompressorV2(head_dim=128, bits=3)
compressed = comp.compress(key_states)
scores = comp.asymmetric_attention_scores(query, compressed)

Initialize the key compressor.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| head_dim | int | Dimension of each attention head. | required |
| bits | int | Total bits per coordinate. | 3 |
| seed | int | Random seed for reproducibility. | 42 |

Source code in src/turboquant_vllm/compressors.py
def __init__(self, head_dim: int, bits: int = 3, *, seed: int = 42) -> None:
    """Initialize the key compressor.

    Args:
        head_dim: Dimension of each attention head.
        bits: Total bits per coordinate (default 3).
        seed: Random seed for reproducibility.
    """
    self.head_dim = head_dim
    self.bits = bits
    self.quantizer = TurboQuantProd(head_dim, bits, seed=seed)
Functions
compress
compress(keys: Tensor) -> CompressedKeys

Compress key tensors.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| keys | Tensor | Key tensor of shape (batch, heads, seq_len, head_dim). | required |

Returns:

| Type | Description |
|------|-------------|
| CompressedKeys | CompressedKeys containing all components for attention estimation. |

Source code in src/turboquant_vllm/compressors.py
def compress(self, keys: torch.Tensor) -> CompressedKeys:
    """Compress key tensors.

    Args:
        keys: Key tensor of shape (batch, heads, seq_len, head_dim).

    Returns:
        CompressedKeys containing all components for attention estimation.
    """
    original_dtype = keys.dtype
    indices, norms, qjl_signs, residual_norms = self.quantizer.quantize(
        keys.float()
    )
    return CompressedKeys(
        indices=indices,
        norms=norms,
        qjl_signs=qjl_signs,
        residual_norms=residual_norms,
        original_dtype=original_dtype,
    )
decompress
decompress(compressed: CompressedKeys) -> Tensor

Reconstruct key tensors from compressed representation.

Note: For attention, prefer asymmetric_attention_scores() which uses the QJL-corrected inner product estimator for better accuracy.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| compressed | CompressedKeys | CompressedKeys from compress(). | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | Reconstructed key tensor in the original dtype. |

Source code in src/turboquant_vllm/compressors.py
def decompress(self, compressed: CompressedKeys) -> torch.Tensor:
    """Reconstruct key tensors from compressed representation.

    Note: For attention, prefer ``asymmetric_attention_scores()`` which
    uses the QJL-corrected inner product estimator for better accuracy.

    Args:
        compressed: CompressedKeys from compress().

    Returns:
        Reconstructed key tensor in the original dtype.
    """
    result = self.quantizer.dequantize(
        compressed.indices,
        compressed.norms,
        compressed.qjl_signs,
        compressed.residual_norms,
    )
    return result.to(compressed.original_dtype)
asymmetric_attention_scores
asymmetric_attention_scores(query: Tensor, compressed: CompressedKeys) -> Tensor

Compute attention scores directly from compressed keys.

Uses the unbiased two-stage inner product estimator rather than decompressing keys and computing standard dot products. This is both more memory-efficient and more accurate.

.. warning:: MEMORY SCALING

The current implementation expands tensors to
(batch, heads, q_len, kv_len, dim) for broadcasting.
This allocates ~5 intermediate tensors at that shape.
For real sequence lengths (kv_len=6144, heads=32, dim=128)
this would use ~500MB+ per call. Suitable for correctness
testing on short sequences only.

TODO: Replace with a chunked or fused Triton kernel for
production use at real sequence lengths.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| query | Tensor | Query tensor, shape (batch, heads, q_len, head_dim). | required |
| compressed | CompressedKeys | CompressedKeys from compress(). | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | Attention logits, shape (batch, heads, q_len, kv_len). |

Source code in src/turboquant_vllm/compressors.py
def asymmetric_attention_scores(
    self, query: torch.Tensor, compressed: CompressedKeys
) -> torch.Tensor:
    """Compute attention scores directly from compressed keys.

    Uses the unbiased two-stage inner product estimator rather than
    decompressing keys and computing standard dot products. This is
    both more memory-efficient and more accurate.

    .. warning:: MEMORY SCALING

        The current implementation expands tensors to
        (batch, heads, q_len, kv_len, dim) for broadcasting.
        This allocates ~5 intermediate tensors at that shape.
        For real sequence lengths (kv_len=6144, heads=32, dim=128)
        this would use ~500MB+ per call. Suitable for correctness
        testing on short sequences only.

        TODO: Replace with a chunked or fused Triton kernel for
        production use at real sequence lengths.

    Args:
        query: Query tensor, shape (batch, heads, q_len, head_dim).
        compressed: CompressedKeys from compress().

    Returns:
        Attention logits, shape (batch, heads, q_len, kv_len).
    """
    b, h, q_len, d = query.shape
    _, _, kv_len, _ = compressed.indices.shape

    # Expand query for broadcasting: (b, h, q_len, 1, d)
    # NOTE: This expand pattern is O(q_len * kv_len * dim) memory.
    # Fine for benchmarking short sequences, not for production.
    q_exp = query.float().unsqueeze(3).expand(b, h, q_len, kv_len, d)
    # Expand compressed key components: (b, h, 1, kv_len, ...)
    idx_exp = compressed.indices.unsqueeze(2).expand(b, h, q_len, kv_len, d)
    n_exp = compressed.norms.unsqueeze(2).expand(b, h, q_len, kv_len, 1)
    qjl_exp = compressed.qjl_signs.unsqueeze(2).expand(
        b, h, q_len, kv_len, self.quantizer.qjl_dim
    )
    rn_exp = compressed.residual_norms.unsqueeze(2).expand(b, h, q_len, kv_len, 1)

    scores = self.quantizer.estimate_inner_product(
        q_exp, idx_exp, n_exp, qjl_exp, rn_exp
    )
    return scores.squeeze(-1).to(query.dtype)
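One way the chunking fix mentioned in the TODO could look is sketched below. This is an assumption about a possible approach, not the package's API: `estimate_scores` is a plain dot-product stand-in for the real two-stage estimator, and only the chunking pattern itself is the point — peak memory becomes proportional to the chunk size rather than the full kv length.

```python
import numpy as np

def estimate_scores(q: np.ndarray, k: np.ndarray) -> np.ndarray:
    # Placeholder for the two-stage inner product estimator:
    # a plain dot product over the last dimension.
    return q @ k.swapaxes(-1, -2)

def chunked_scores(q: np.ndarray, k: np.ndarray, chunk: int = 1024) -> np.ndarray:
    # Slice the kv axis so peak memory is O(q_len * chunk * dim)
    # instead of O(q_len * kv_len * dim).
    kv_len = k.shape[-2]
    parts = [estimate_scores(q, k[..., i:i + chunk, :])
             for i in range(0, kv_len, chunk)]
    return np.concatenate(parts, axis=-1)

q = np.random.randn(1, 2, 4, 8)    # (batch, heads, q_len, dim)
k = np.random.randn(1, 2, 100, 8)  # (batch, heads, kv_len, dim)
assert np.allclose(chunked_scores(q, k, chunk=32), estimate_scores(q, k))
```

The concatenated chunk results are exactly equal to the unchunked scores; a fused Triton kernel would go further by never materializing decompressed keys at all.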

CompressedDynamicCache

CompressedDynamicCache(cache: Any, head_dim: int, bits: int = 3, *, seed: int = 42)

KV cache with real VRAM savings via compressed index storage.

Stores TurboQuant-compressed representations and dequantizes lazily on each cache read. Only one layer's decompressed tensors are held in memory at a time — previous layers are freed on the next update.

Storage per token per head (head_dim=128):

| Mode | Dtype | Bytes | Compression | Quality |
|------|-------|-------|-------------|---------|
| FP16 baseline | fp16 | 256 | 1.0x | — |
| TQ3 (3-bit) | uint8 | 132 | 1.94x | ~95% cosine |
| TQ4 (4-bit) | nibble | 68 | 3.76x | ~97% cosine |

At bits=4, indices are nibble-packed (two 4-bit values per byte), nearly doubling compression over TQ3 with better quality. Float32 norms are required — fp16 causes output degradation at 10K+ token sequences due to accumulated precision loss.
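The nibble-packed layout can be illustrated with a small numpy sketch (the function names are illustrative, not the package's internals). Two 4-bit indices share one uint8 byte, so 128 coordinates pack into 64 bytes; adding the 4-byte fp32 norm gives the 68 bytes per token per head shown above.

```python
import numpy as np

def pack_nibbles(indices: np.ndarray) -> np.ndarray:
    """Pack indices in [0, 15] along the last axis: (..., d) -> (..., d//2)."""
    lo = indices[..., 0::2].astype(np.uint8)  # even coordinates -> low nibble
    hi = indices[..., 1::2].astype(np.uint8)  # odd coordinates -> high nibble
    return lo | (hi << 4)

def unpack_nibbles(packed: np.ndarray) -> np.ndarray:
    """Invert pack_nibbles: (..., d//2) -> (..., d)."""
    out = np.empty(packed.shape[:-1] + (packed.shape[-1] * 2,), dtype=np.uint8)
    out[..., 0::2] = packed & 0x0F
    out[..., 1::2] = packed >> 4
    return out

idx = np.random.randint(0, 16, size=(2, 128))
assert np.array_equal(unpack_nibbles(pack_nibbles(idx)), idx)
```

The even-`head_dim` requirement falls directly out of this pairing: an odd last dimension would leave a nibble without a partner.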

Integration strategy: non-invasive method replacement (same pattern as TurboQuantKVCache). Patches update() and get_seq_length() on the wrapped DynamicCache.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| cache | Any | The wrapped DynamicCache instance. |
| key_compressor | TurboQuantCompressorMSE | Compressor for key tensors. |
| value_compressor | TurboQuantCompressorMSE | Compressor for value tensors. |
| bits | int | Quantization bits per coordinate. |
| head_dim | int | Model head dimension. |
| enabled | bool | Whether compression is active. |
| fused_mode | bool | When True, skip decompression in update() (fused kernel reads compressed data via get_compressed()). |
| rotation | Tensor | Shared rotation matrix [head_dim, head_dim]. |
| centroids | Tensor | Shared codebook [2^bits]. |

Examples:

from transformers import DynamicCache

cache = DynamicCache()
compressed = CompressedDynamicCache(cache, head_dim=128, bits=4)
compressed.vram_bytes()  # 0

Initialize the compressed KV cache wrapper.

Sets up compressors, internal storage for compressed representations, and incremental decompressed buffers. fused_mode starts disabled.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| cache | Any | A HuggingFace DynamicCache instance to wrap. | required |
| head_dim | int | Dimension of each attention head. Must be even when bits=4 (required for nibble packing pairs). | required |
| bits | int | Quantization bits per coordinate. Use 4 for nibble-packed storage (3.76x compression). | 3 |
| seed | int | Random seed for reproducibility. | 42 |

Raises:

| Type | Description |
|------|-------------|
| ValueError | If bits=4 and head_dim is odd. |

Source code in src/turboquant_vllm/kv_cache.py
def __init__(
    self,
    cache: Any,
    head_dim: int,
    bits: int = 3,
    *,
    seed: int = 42,
) -> None:
    """Initialize the compressed KV cache wrapper.

    Sets up compressors, internal storage for compressed representations,
    and incremental decompressed buffers. ``fused_mode`` starts disabled.

    Args:
        cache: A HuggingFace DynamicCache instance to wrap.
        head_dim: Dimension of each attention head. Must be even
            when ``bits=4`` (required for nibble packing pairs).
        bits: Quantization bits per coordinate (default 3). Use 4
            for nibble-packed storage (3.76x compression).
        seed: Random seed for reproducibility.

    Raises:
        ValueError: If ``bits=4`` and ``head_dim`` is odd.
    """
    if bits == 4 and head_dim % 2 != 0:
        msg = f"bits=4 requires even head_dim for nibble packing, got {head_dim}"
        raise ValueError(msg)

    self.cache = cache
    self.head_dim = head_dim
    self.bits = bits
    self._nibble_packed = bits == 4
    self.enabled = True

    self.key_compressor = TurboQuantCompressorMSE(head_dim, bits, seed=seed)
    self.value_compressor = TurboQuantCompressorMSE(head_dim, bits, seed=seed)

    self._compressed_keys: list[_CompressedLayer] = []
    self._compressed_values: list[_CompressedLayer] = []
    self._decompressed_k: list[torch.Tensor | None] = []
    self._decompressed_v: list[torch.Tensor | None] = []
    self._original_dtype: torch.dtype = torch.bfloat16
    self.fused_mode = False

    # Patch cache methods
    self._original_update = cache.update
    self._original_get_seq_length = cache.get_seq_length
    cache.update = self._compressed_update
    cache.get_seq_length = self._compressed_get_seq_length
Attributes
rotation property
rotation: Tensor

Shared orthogonal rotation matrix [head_dim, head_dim] fp32.

K and V use the same rotation (same seed).

Returns:

| Type | Description |
|------|-------------|
| Tensor | The rotation matrix from the key compressor's quantizer. |

centroids property
centroids: Tensor

Shared Lloyd-Max codebook [2^bits] fp32.

Returns:

| Type | Description |
|------|-------------|
| Tensor | Centroid values from the key compressor's quantizer. |

Functions
get_compressed
get_compressed(layer_idx: int) -> tuple[Tensor, Tensor, Tensor, Tensor]

Return compressed K and V for a layer (fused kernel API).

Provides the raw nibble-packed indices and norms without dequantization, for use by the fused TQ4 Flash Attention kernel.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| layer_idx | int | Transformer layer index. | required |

Returns:

| Type | Description |
|------|-------------|
| tuple[Tensor, Tensor, Tensor, Tensor] | (k_packed, k_norms, v_packed, v_norms) where packed tensors are uint8 [batch, heads, seq, head_dim//2] and norms are fp32 [batch, heads, seq, 1]. |

Source code in src/turboquant_vllm/kv_cache.py
def get_compressed(
    self, layer_idx: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return compressed K and V for a layer (fused kernel API).

    Provides the raw nibble-packed indices and norms without
    dequantization, for use by the fused TQ4 Flash Attention kernel.

    Args:
        layer_idx: Transformer layer index.

    Returns:
        ``(k_packed, k_norms, v_packed, v_norms)`` where packed tensors
        are uint8 ``[batch, heads, seq, head_dim//2]`` and norms are
        fp32 ``[batch, heads, seq, 1]``.
    """
    k = self._compressed_keys[layer_idx]
    v = self._compressed_values[layer_idx]
    return k.indices, k.norms, v.indices, v.norms
disable
disable() -> None

Disable compression, passing through to original update.

Source code in src/turboquant_vllm/kv_cache.py
def disable(self) -> None:
    """Disable compression, passing through to original update."""
    self.enabled = False
enable
enable() -> None

Re-enable compression after disable().

Source code in src/turboquant_vllm/kv_cache.py
def enable(self) -> None:
    """Re-enable compression after disable()."""
    self.enabled = True
restore
restore() -> None

Restore original methods on the wrapped cache.

Call this to fully unwrap the cache and remove all TurboQuant interception.

Source code in src/turboquant_vllm/kv_cache.py
def restore(self) -> None:
    """Restore original methods on the wrapped cache.

    Call this to fully unwrap the cache and remove all TurboQuant
    interception.
    """
    self.cache.update = self._original_update
    self.cache.get_seq_length = self._original_get_seq_length
vram_bytes
vram_bytes() -> int

Calculate total VRAM used by compressed storage.

Returns:

| Type | Description |
|------|-------------|
| int | Total bytes across all compressed layers (keys + values). |

Source code in src/turboquant_vllm/kv_cache.py
def vram_bytes(self) -> int:
    """Calculate total VRAM used by compressed storage.

    Returns:
        Total bytes across all compressed layers (keys + values).
    """
    total = 0
    for layer in [*self._compressed_keys, *self._compressed_values]:
        total += layer.indices.nelement() * layer.indices.element_size()
        total += layer.norms.nelement() * layer.norms.element_size()
    return total
baseline_vram_bytes
baseline_vram_bytes() -> int

Estimate FP16 VRAM that would be used without compression.

Accounts for nibble-packed indices by doubling the last dimension to recover the original head_dim.

Returns:

| Type | Description |
|------|-------------|
| int | Total bytes if keys and values were stored as FP16 tensors. |

Source code in src/turboquant_vllm/kv_cache.py
def baseline_vram_bytes(self) -> int:
    """Estimate FP16 VRAM that would be used without compression.

    Accounts for nibble-packed indices by doubling the last
    dimension to recover the original head_dim.

    Returns:
        Total bytes if keys and values were stored as FP16 tensors.
    """
    total = 0
    for layer in [*self._compressed_keys, *self._compressed_values]:
        b, h, s, d = layer.indices.shape
        # Nibble-packed indices have d = head_dim // 2
        if layer.packed:
            d = d * 2
        total += b * h * s * d * 2  # FP16 = 2 bytes per element
    return total
compression_stats
compression_stats() -> dict[str, Any]

Return compression statistics for reporting.

Reports the true head_dim (not the packed index dimension) and includes a nibble_packed flag.

Returns:

| Type | Description |
|------|-------------|
| dict[str, Any] | Dict with layer count, sequence length, compressed/baseline sizes in MiB, compression ratio, packing mode, and VRAM savings. |

Source code in src/turboquant_vllm/kv_cache.py
def compression_stats(self) -> dict[str, Any]:
    """Return compression statistics for reporting.

    Reports the true ``head_dim`` (not the packed index dimension)
    and includes a ``nibble_packed`` flag.

    Returns:
        Dict with layer count, sequence length, compressed/baseline
        sizes in MiB, compression ratio, packing mode, and VRAM savings.
    """
    if not self._compressed_keys:
        return {}

    compressed_bytes = self.vram_bytes()
    baseline_bytes = self.baseline_vram_bytes()
    ratio = baseline_bytes / compressed_bytes if compressed_bytes > 0 else 0.0

    layer = self._compressed_keys[0]
    b, h, s, _ = layer.indices.shape

    return {
        "num_layers": len(self._compressed_keys),
        "seq_len": s,
        "batch_size": b,
        "num_heads": h,
        "head_dim": self.head_dim,
        "bits": self.bits,
        "nibble_packed": self._nibble_packed,
        "compressed_mib": compressed_bytes / (1024 * 1024),
        "baseline_mib": baseline_bytes / (1024 * 1024),
        "compression_ratio": round(ratio, 2),
        "savings_mib": (baseline_bytes - compressed_bytes) / (1024 * 1024),
    }

TurboQuantKVCache

TurboQuantKVCache(
    cache: Any,
    head_dim: int,
    bits: int = 3,
    *,
    seed: int = 42,
    compress_keys: bool = True,
    compress_values: bool = True,
)

Transparent KV cache compression wrapper (drop-in mode).

Intercepts cache updates to compress key/value tensors before they are stored. Both keys and values use TurboQuantCompressorMSE (full MSE-optimal quantization at the configured bit-width).

This is the "drop-in" approach where standard attention (Q @ K^T) operates on decompressed keys. For the QJL-corrected inner product path (TurboQuantProd), a custom attention kernel would be needed — see TurboQuantCompressorV2.asymmetric_attention_scores().
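The "non-invasive method replacement" pattern can be sketched without transformers installed. `DummyCache` below is a stand-in for DynamicCache, and `MiniWrapper` is a simplified illustration of the patch/restore idea, not the package's implementation:

```python
class DummyCache:
    """Stand-in for transformers.DynamicCache."""
    def update(self, key, value, layer_idx):
        return key, value

class MiniWrapper:
    """Sketch of the wrap/restore pattern the KV cache wrappers use."""
    def __init__(self, cache):
        self.cache = cache
        self.calls = 0
        self._original_update = cache.update  # keep a handle for restore()
        cache.update = self._patched_update   # intercept cache writes
    def _patched_update(self, key, value, layer_idx):
        self.calls += 1                       # compression would happen here
        return self._original_update(key, value, layer_idx)
    def restore(self):
        # Put the original bound method back, removing all interception.
        self.cache.update = self._original_update

cache = DummyCache()
wrapper = MiniWrapper(cache)
cache.update("k", "v", 0)
assert wrapper.calls == 1
wrapper.restore()
cache.update("k", "v", 0)
assert wrapper.calls == 1  # no longer intercepted
```

Because only instance attributes are patched, other cache objects of the same class are untouched, and `restore()` leaves no trace.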

Attributes:

| Name | Type | Description |
|------|------|-------------|
| cache | Any | The wrapped DynamicCache instance. |
| key_compressor | TurboQuantCompressorMSE | Compressor for key tensors. |
| value_compressor | TurboQuantCompressorMSE | Compressor for value tensors. |
| bits | int | Quantization bits per coordinate. |
| head_dim | int | Model head dimension. |
| enabled | bool | Whether compression is active. |

Examples:

from transformers import DynamicCache

cache = DynamicCache()
tq = TurboQuantKVCache(cache, head_dim=128, bits=3)
tq.enabled  # True

Initialize the TurboQuant KV cache wrapper.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| cache | Any | A HuggingFace DynamicCache instance to wrap. | required |
| head_dim | int | Dimension of each attention head. | required |
| bits | int | Quantization bits per coordinate. | 3 |
| seed | int | Random seed for reproducibility. | 42 |
| compress_keys | bool | Whether to compress key tensors. | True |
| compress_values | bool | Whether to compress value tensors. | True |

Source code in src/turboquant_vllm/kv_cache.py
def __init__(
    self,
    cache: Any,
    head_dim: int,
    bits: int = 3,
    *,
    seed: int = 42,
    compress_keys: bool = True,
    compress_values: bool = True,
) -> None:
    """Initialize the TurboQuant KV cache wrapper.

    Args:
        cache: A HuggingFace DynamicCache instance to wrap.
        head_dim: Dimension of each attention head.
        bits: Quantization bits per coordinate (default 3).
        seed: Random seed for reproducibility.
        compress_keys: Whether to compress key tensors.
        compress_values: Whether to compress value tensors.
    """
    self.cache = cache
    self.head_dim = head_dim
    self.bits = bits
    self.compress_keys = compress_keys
    self.compress_values = compress_values
    self.enabled = True

    # Drop-in mode: use MSE-only for BOTH keys and values.
    # TurboQuantCompressorV2 (TurboQuantProd) allocates 1 bit to QJL correction,
    # but QJL only helps when attention calls estimate_inner_product() directly.
    # Standard attention does Q @ K^T on decompressed keys, so QJL is invisible
    # and we lose 1 bit of MSE resolution for nothing. Full 3-bit MSE gives
    # ~95% cosine sim vs ~87% with 2-bit MSE + 1-bit QJL.
    # See: https://dejan.ai/blog/turboquant/ (TurboQuant_mse for drop-in cache)
    self.key_compressor = TurboQuantCompressorMSE(head_dim, bits, seed=seed)
    self.value_compressor = TurboQuantCompressorMSE(head_dim, bits, seed=seed)

    # Patch the cache's update method
    self._original_update = cache.update
    cache.update = self._compressed_update
Functions
disable
disable() -> None

Disable compression, passing through to original update.

Useful for A/B benchmarking within the same run.

Source code in src/turboquant_vllm/kv_cache.py
def disable(self) -> None:
    """Disable compression, passing through to original update.

    Useful for A/B benchmarking within the same run.
    """
    self.enabled = False
enable
enable() -> None

Re-enable compression after disable().

Source code in src/turboquant_vllm/kv_cache.py
def enable(self) -> None:
    """Re-enable compression after disable()."""
    self.enabled = True
restore
restore() -> None

Restore the original update method on the wrapped cache.

Call this to fully unwrap the cache and remove all TurboQuant interception.

Source code in src/turboquant_vllm/kv_cache.py
def restore(self) -> None:
    """Restore the original update method on the wrapped cache.

    Call this to fully unwrap the cache and remove all TurboQuant
    interception.
    """
    self.cache.update = self._original_update

LloydMaxCodebook dataclass

LloydMaxCodebook(centroids: Tensor, boundaries: Tensor, bits: int, dim: int)

Precomputed optimal scalar quantizer for a given dimension and bit-width.

The codebook stores centroids and boundaries computed by the Lloyd-Max algorithm. It maps continuous coordinate values to discrete indices and back via nearest-centroid lookup.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| centroids | Tensor | Reconstruction values, shape (2^bits,). |
| boundaries | Tensor | Partition boundaries, shape (2^bits - 1,). |
| bits | int | Number of quantization bits. |
| dim | int | Vector dimension used to compute the codebook. |

Examples:

Round-trip quantize and dequantize a tensor:

codebook = LloydMaxCodebook(centroids, boundaries, bits=3, dim=128)
indices = codebook.quantize(x)
x_hat = codebook.dequantize(indices)
Functions
quantize
quantize(x: Tensor) -> Tensor

Map continuous values to nearest centroid indices.

Uses bucket search on partition boundaries for O(log n) lookup.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | Tensor | Input tensor of any shape. | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | Integer tensor of same shape with centroid indices in [0, 2^bits - 1]. |

Source code in src/turboquant_vllm/lloyd_max.py
def quantize(self, x: torch.Tensor) -> torch.Tensor:
    """Map continuous values to nearest centroid indices.

    Uses bucket search on partition boundaries for O(log n) lookup.

    Args:
        x: Input tensor of any shape.

    Returns:
        Integer tensor of same shape with centroid indices in
        [0, 2^bits - 1].
    """
    bounds = self.boundaries.to(x.device)
    # bucketize returns the index of the bucket each value falls into
    return torch.bucketize(x, bounds)
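The boundary lookup can be mirrored in numpy, where `np.searchsorted` plays the role of `torch.bucketize` (the centroid values below are toy numbers, not a real Lloyd-Max solution):

```python
import numpy as np

centroids = np.array([-1.5, -0.5, 0.5, 1.5])       # toy 2-bit codebook
boundaries = (centroids[:-1] + centroids[1:]) / 2  # midpoints: [-1., 0., 1.]

x = np.array([-2.0, -0.7, 0.2, 3.0])
idx = np.searchsorted(boundaries, x)  # partition index per value, in [0, 3]
assert idx.tolist() == [0, 1, 2, 3]

x_hat = centroids[idx]                # nearest-centroid reconstruction
assert x_hat.tolist() == [-1.5, -0.5, 0.5, 1.5]
```

Placing boundaries at centroid midpoints makes the binary-search lookup equivalent to a nearest-centroid search, which is why the O(log n) bucket search returns the MSE-optimal index.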
dequantize
dequantize(indices: Tensor) -> Tensor

Reconstruct continuous values from centroid indices.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| indices | Tensor | Integer tensor of centroid indices. | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | Float tensor of reconstructed values with same shape as indices. |

Source code in src/turboquant_vllm/lloyd_max.py
def dequantize(self, indices: torch.Tensor) -> torch.Tensor:
    """Reconstruct continuous values from centroid indices.

    Args:
        indices: Integer tensor of centroid indices.

    Returns:
        Float tensor of reconstructed values with same shape as indices.
    """
    cents = self.centroids.to(indices.device)
    return cents[indices]

TurboQuantMSE

TurboQuantMSE(dim: int, bits: int, *, seed: int = 42)

Stage 1 quantizer: rotation + Lloyd-Max scalar quantization.

Achieves near-optimal MSE distortion rate for high-dimensional vectors by exploiting the concentrated Beta distribution that emerges after random rotation.
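The concentration claim is easy to check numerically: each coordinate of a random unit vector in R^d has variance 1/d, so after rotation every coordinate falls in the same narrow, known distribution that a single shared scalar codebook can cover. A small numpy check (not part of the package):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 128
x = rng.standard_normal((10_000, d))
u = x / np.linalg.norm(x, axis=-1, keepdims=True)  # points on the unit sphere

# Coordinates of a uniform unit vector concentrate with std ~ 1/sqrt(d).
coord_std = u[:, 0].std()
assert abs(coord_std - 1 / np.sqrt(d)) < 0.01
```

This is why one precomputed Lloyd-Max codebook per (dim, bits) pair suffices for all vectors: the rotation makes every coordinate look like a draw from this fixed distribution.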

Attributes:

| Name | Type | Description |
|------|------|-------------|
| dim | int | Vector dimension. |
| bits | int | Quantization bit-width. |
| codebook | LloydMaxCodebook | Precomputed Lloyd-Max codebook. |
| rotation | Tensor | Orthogonal rotation matrix, shape (dim, dim). |

Examples:

quantizer = TurboQuantMSE(dim=64, bits=4)
indices, norms = quantizer.quantize(torch.randn(8, 64))
reconstructed = quantizer.dequantize(indices, norms)

Initialize the MSE quantizer.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| dim | int | Vector dimension (head dimension of the model). | required |
| bits | int | Quantization bits per coordinate (2-4 typical). | required |
| seed | int | Random seed for the rotation matrix. | 42 |

Source code in src/turboquant_vllm/quantizer.py
def __init__(self, dim: int, bits: int, *, seed: int = 42) -> None:
    """Initialize the MSE quantizer.

    Args:
        dim: Vector dimension (head dimension of the model).
        bits: Quantization bits per coordinate (2-4 typical).
        seed: Random seed for the rotation matrix.
    """
    self.dim = dim
    self.bits = bits
    centroids, boundaries = solve_lloyd_max(dim, bits)
    self.codebook = LloydMaxCodebook(
        centroids=centroids,
        boundaries=boundaries,
        bits=bits,
        dim=dim,
    )
    self.rotation = _generate_rotation_matrix(dim, seed=seed)
Functions
quantize
quantize(x: Tensor) -> tuple[Tensor, Tensor]

Quantize vectors to centroid indices.

Applies rotation, extracts norms, normalizes to unit sphere, then quantizes each coordinate independently.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | Tensor | Input tensor of shape (..., dim). | required |

Returns:

| Type | Description |
|------|-------------|
| tuple[Tensor, Tensor] | Tuple of (indices, norms) where indices is a long tensor of shape (..., dim) and norms is a float tensor of shape (..., 1). |

Source code in src/turboquant_vllm/quantizer.py
def quantize(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize vectors to centroid indices.

    Applies rotation, extracts norms, normalizes to unit sphere,
    then quantizes each coordinate independently.

    Args:
        x: Input tensor of shape (..., dim).

    Returns:
        Tuple of (indices, norms) where indices is a long tensor of
        shape (..., dim) and norms is a float tensor of shape (..., 1).
    """
    # Store original shape and flatten to 2D
    orig_shape = x.shape
    flat = x.reshape(-1, self.dim).float()

    # Extract and store norms
    norms = torch.norm(flat, dim=-1, keepdim=True)
    normalized = flat / (norms + 1e-10)

    # Apply rotation: y = x @ Pi^T
    pi = self.rotation.to(flat.device)
    rotated = normalized @ pi.T

    # Quantize each coordinate independently
    indices = self.codebook.quantize(rotated)

    return indices.reshape(orig_shape), norms.reshape(*orig_shape[:-1], 1)
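The same pipeline can be sketched end to end in numpy, with a QR-based random rotation and a toy uniform codebook standing in for the Lloyd-Max one (all constants here are illustrative, not the package's values):

```python
import numpy as np

rng = np.random.default_rng(42)
d = 64
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))  # random orthogonal rotation

levels = np.linspace(-0.3, 0.3, 16)               # toy 4-bit scalar codebook
bounds = (levels[:-1] + levels[1:]) / 2           # midpoint boundaries

x = rng.standard_normal((8, d))
norms = np.linalg.norm(x, axis=-1, keepdims=True)
rotated = (x / norms) @ Q.T                       # rotate the unit vectors
idx = np.searchsorted(bounds, rotated)            # per-coordinate quantize
x_hat = (levels[idx] @ Q) * norms                 # inverse rotation + rescale

cos = (x * x_hat).sum(-1) / (
    np.linalg.norm(x, axis=-1) * np.linalg.norm(x_hat, axis=-1)
)
assert cos.min() > 0.9  # round trip roughly preserves direction
```

Only `idx` (small integers) and `norms` need to be stored; the rotation and codebook are shared across all vectors.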
dequantize
dequantize(indices: Tensor, norms: Tensor) -> Tensor

Reconstruct vectors from centroid indices and norms.

Looks up centroids, applies inverse rotation, and rescales by stored norms.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| indices | Tensor | Long tensor of centroid indices, shape (..., dim). | required |
| norms | Tensor | Float tensor of vector norms, shape (..., 1). | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | Reconstructed float tensor of shape (..., dim). |

Source code in src/turboquant_vllm/quantizer.py
def dequantize(self, indices: torch.Tensor, norms: torch.Tensor) -> torch.Tensor:
    """Reconstruct vectors from centroid indices and norms.

    Looks up centroids, applies inverse rotation, and rescales
    by stored norms.

    Args:
        indices: Long tensor of centroid indices, shape (..., dim).
        norms: Float tensor of vector norms, shape (..., 1).

    Returns:
        Reconstructed float tensor of shape (..., dim).
    """
    orig_shape = indices.shape
    flat_idx = indices.reshape(-1, self.dim)
    flat_norms = norms.reshape(-1, 1)

    # Lookup centroids
    reconstructed = self.codebook.dequantize(flat_idx)

    # Inverse rotation: x = y @ Pi
    pi = self.rotation.to(reconstructed.device)
    unrotated = reconstructed @ pi

    # Rescale by norms
    result = unrotated * flat_norms

    return result.reshape(orig_shape)
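The normalize → rotate → scalar-quantize → invert pipeline above can be sketched end to end without the library. This is a hedged illustration only: it substitutes a uniform codebook on [-1, 1] for the library's Lloyd-Max centroids, and a QR-derived orthogonal matrix for the stored `rotation`; shapes and names are otherwise analogous to `quantize`/`dequantize`.

```python
import numpy as np

rng = np.random.default_rng(0)
d, bits = 8, 3
levels = 2 ** bits

# Random orthogonal rotation standing in for the class's stored `rotation`
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))

x = rng.standard_normal((16, d))
norms = np.linalg.norm(x, axis=-1, keepdims=True)
rotated = (x / norms) @ Q.T            # unit-norm rows, each coordinate in [-1, 1]

# Uniform scalar codebook on [-1, 1]; the library uses Lloyd-Max centroids instead
idx = np.clip(np.round((rotated + 1) / 2 * (levels - 1)), 0, levels - 1)
recon = (idx / (levels - 1) * 2 - 1) @ Q * norms   # inverse rotation, rescale

rel_err = np.linalg.norm(x - recon) / np.linalg.norm(x)
```

Because the rotation is orthogonal, the per-coordinate quantization error is not amplified on the way back; the relative round-trip error is set entirely by the scalar codebook's step size.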

TurboQuantProd

TurboQuantProd(dim: int, bits: int, *, qjl_dim: int | None = None, seed: int = 42)

Two-stage quantizer with QJL correction for unbiased inner products.

Allocates (bits-1) bits to Lloyd-Max MSE quantization and 1 bit to Quantized Johnson-Lindenstrauss residual correction. The QJL step eliminates bias in dot-product estimation, which is critical for attention score computation (Q·K^T).

The unbiased estimator is

<q, k> ~ <q, k_mse> + ||r|| * sqrt(pi/2) / m * <S@q, sign(S@r)>

where k_mse is the Stage 1 reconstruction, r is the quantization residual, and S is a random Gaussian projection matrix with m rows.
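The key fact behind this estimator is that for Gaussian rows g, E[(g·q) sign(g·r)] = sqrt(2/pi) * <q, r> / ||r||, so the sqrt(pi/2)/m scale makes the sign-only term an unbiased estimate of <q, r>. A minimal numerical check (standalone numpy, not the library API; r here is just a small random vector standing in for a quantization residual):

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 64, 20000                       # vector dim, number of projection rows

q = rng.standard_normal(d)
r = 0.1 * rng.standard_normal(d)       # stand-in for a quantization residual

S = rng.standard_normal((m, d))        # raw Gaussian rows
# Sign-based estimator of <q, r>: only sign(S @ r) and ||r|| are "stored"
est = np.linalg.norm(r) * np.sqrt(np.pi / 2) / m * (S @ q) @ np.sign(S @ r)
exact = q @ r
```

With large m the estimate concentrates around the exact inner product at a 1/sqrt(m) rate, which is why one sign bit per projection dimension suffices to debias attention scores.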

Attributes:

Name Type Description
dim int

Vector dimension.

bits int

Total bit budget (bits-1 for MSE, 1 for QJL).

mse_quantizer TurboQuantMSE

Stage 1 quantizer with (bits-1) bits.

qjl_dim int

Number of QJL projection dimensions.

qjl_matrix Tensor

Random Gaussian projection matrix.

Examples:

quantizer = TurboQuantProd(dim=64, bits=4)
indices, norms, signs, res_norms = quantizer.quantize(torch.randn(8, 64))
scores = quantizer.estimate_inner_product(
    torch.randn(1, 64), indices, norms, signs, res_norms
)

Initialize the two-stage quantizer.

Parameters:

Name Type Description Default
dim int

Vector dimension (head dimension of the model).

required
bits int

Total bit budget per coordinate. Must be >= 2 (1 bit for MSE + 1 bit for QJL minimum).

required
qjl_dim int | None

Number of QJL projection dimensions. Defaults to dim (standard JL dimensionality).

None
seed int

Random seed for rotation and projection matrices.

42

Raises:

Type Description
ValueError

If bits < 2.

Source code in src/turboquant_vllm/quantizer.py
def __init__(
    self,
    dim: int,
    bits: int,
    *,
    qjl_dim: int | None = None,
    seed: int = 42,
) -> None:
    """Initialize the two-stage quantizer.

    Args:
        dim: Vector dimension (head dimension of the model).
        bits: Total bit budget per coordinate. Must be >= 2
            (1 bit for MSE + 1 bit for QJL minimum).
        qjl_dim: Number of QJL projection dimensions. Defaults
            to dim (standard JL dimensionality).
        seed: Random seed for rotation and projection matrices.

    Raises:
        ValueError: If bits < 2.
    """
    if bits < 2:
        msg = f"TurboQuantProd requires bits >= 2, got {bits}"
        raise ValueError(msg)

    self.dim = dim
    self.bits = bits
    self.mse_quantizer = TurboQuantMSE(dim, bits - 1, seed=seed)

    self.qjl_dim = qjl_dim or dim
    gen = torch.Generator().manual_seed(seed + 1)
    # Raw N(0, 1) rows: the sqrt(pi/2) / m scale in estimate_inner_product
    # assumes unnormalized Gaussian entries. Dividing by sqrt(qjl_dim) here
    # would shrink the QJL correction by a factor of sqrt(qjl_dim).
    self.qjl_matrix = torch.randn(self.qjl_dim, dim, generator=gen)
Functions
quantize
quantize(x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]

Quantize vectors with MSE + QJL correction.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (..., dim).

required

Returns:

Type Description
tuple[Tensor, Tensor, Tensor, Tensor]

Tuple of (indices, norms, qjl_signs, residual_norms):

- indices: Lloyd-Max centroid indices, shape (..., dim)
- norms: Vector norms, shape (..., 1)
- qjl_signs: Sign bits of projected residuals, shape (..., qjl_dim)
- residual_norms: Norms of quantization residuals, shape (..., 1)

Source code in src/turboquant_vllm/quantizer.py
def quantize(
    self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Quantize vectors with MSE + QJL correction.

    Args:
        x: Input tensor of shape (..., dim).

    Returns:
        Tuple of (indices, norms, qjl_signs, residual_norms):
            - indices: Lloyd-Max centroid indices, shape (..., dim)
            - norms: Vector norms, shape (..., 1)
            - qjl_signs: Sign bits of projected residuals, shape (..., qjl_dim)
            - residual_norms: Norms of quantization residuals, shape (..., 1)
    """
    # Stage 1: MSE quantization
    indices, norms = self.mse_quantizer.quantize(x)

    # Compute residual: r = x - dequant(quant(x))
    reconstructed = self.mse_quantizer.dequantize(indices, norms)
    residual = x.float() - reconstructed

    # Residual norms for scaling
    residual_norms = torch.norm(residual, dim=-1, keepdim=True)

    # Stage 2: QJL projection → store only signs
    s = self.qjl_matrix.to(x.device)
    projected = residual.reshape(-1, self.dim) @ s.T
    qjl_signs = torch.sign(projected).reshape(*x.shape[:-1], self.qjl_dim)
    # Replace zeros with +1 (ties go positive)
    qjl_signs[qjl_signs == 0] = 1.0

    return indices, norms, qjl_signs, residual_norms
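Note that `qjl_signs` is returned as a float tensor of ±1 values; the advertised 1-bit cost only materializes once the signs are packed. A hedged sketch of that packing step using numpy (the library's actual storage layout may differ; this only shows that the sign information round-trips through 1 bit per dimension):

```python
import numpy as np

rng = np.random.default_rng(0)
signs = np.sign(rng.standard_normal(128)).astype(np.float32)
signs[signs == 0] = 1.0                # same tie-breaking as quantize() above

packed = np.packbits(signs > 0)        # 128 sign floats -> 16 uint8 bytes
restored = np.unpackbits(packed).astype(np.float32) * 2 - 1
```

At 32x compression over float32 signs, this is what turns the extra QJL stage into a genuine 1-bit-per-coordinate cost.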
dequantize
dequantize(
    indices: Tensor, norms: Tensor, qjl_signs: Tensor, residual_norms: Tensor
) -> Tensor

Reconstruct vectors from compressed representation.

Note: Full reconstruction is approximate. For attention computation, use estimate_inner_product instead — it's more accurate because QJL corrects inner-product bias, not reconstruction bias.

Parameters:

Name Type Description Default
indices Tensor

Lloyd-Max centroid indices, shape (..., dim).

required
norms Tensor

Vector norms, shape (..., 1).

required
qjl_signs Tensor

QJL sign bits, shape (..., qjl_dim).

required
residual_norms Tensor

Residual norms, shape (..., 1).

required

Returns:

Type Description
Tensor

Approximately reconstructed tensor of shape (..., dim).

Source code in src/turboquant_vllm/quantizer.py
def dequantize(
    self,
    indices: torch.Tensor,
    norms: torch.Tensor,
    qjl_signs: torch.Tensor,
    residual_norms: torch.Tensor,
) -> torch.Tensor:
    """Reconstruct vectors from compressed representation.

    Note: Full reconstruction is approximate. For attention computation,
    use ``estimate_inner_product`` instead — it's more accurate because
    QJL corrects inner-product bias, not reconstruction bias.

    Args:
        indices: Lloyd-Max centroid indices, shape (..., dim).
        norms: Vector norms, shape (..., 1).
        qjl_signs: QJL sign bits, shape (..., qjl_dim).
        residual_norms: Residual norms, shape (..., 1).

    Returns:
        Approximately reconstructed tensor of shape (..., dim).
    """
    return self.mse_quantizer.dequantize(indices, norms)
estimate_inner_product
estimate_inner_product(
    query: Tensor,
    indices: Tensor,
    norms: Tensor,
    qjl_signs: Tensor,
    residual_norms: Tensor,
) -> Tensor

Compute unbiased inner product estimate between query and compressed key.

Uses the two-stage estimator

<q, k> ~ <q, k_mse> + ||r|| * sqrt(pi/2) / m * <S@q, signs>

Parameters:

Name Type Description Default
query Tensor

Query vectors, shape (..., dim).

required
indices Tensor

Compressed key indices, shape (..., dim).

required
norms Tensor

Key norms, shape (..., 1).

required
qjl_signs Tensor

QJL sign bits for keys, shape (..., qjl_dim).

required
residual_norms Tensor

Key residual norms, shape (..., 1).

required

Returns:

Type Description
Tensor

Inner product estimates, shape matching broadcast of query and key batch dimensions.

Source code in src/turboquant_vllm/quantizer.py
def estimate_inner_product(
    self,
    query: torch.Tensor,
    indices: torch.Tensor,
    norms: torch.Tensor,
    qjl_signs: torch.Tensor,
    residual_norms: torch.Tensor,
) -> torch.Tensor:
    """Compute unbiased inner product estimate between query and compressed key.

    Uses the two-stage estimator:
        <q, k> ~ <q, k_mse> + ||r|| * sqrt(pi/2) / m * <S@q, signs>

    Args:
        query: Query vectors, shape (..., dim).
        indices: Compressed key indices, shape (..., dim).
        norms: Key norms, shape (..., 1).
        qjl_signs: QJL sign bits for keys, shape (..., qjl_dim).
        residual_norms: Key residual norms, shape (..., 1).

    Returns:
        Inner product estimates, shape matching broadcast of query and key
        batch dimensions.
    """
    # MSE component: <q, k_mse>
    k_mse = self.mse_quantizer.dequantize(indices, norms)
    mse_term = (query.float() * k_mse).sum(dim=-1, keepdim=True)

    # QJL correction: ||r|| * sqrt(pi/2) / m * <S@q, signs>
    s = self.qjl_matrix.to(query.device)
    q_projected = query.float().reshape(-1, self.dim) @ s.T
    q_projected = q_projected.reshape(*query.shape[:-1], self.qjl_dim)

    qjl_correction = (q_projected * qjl_signs).sum(dim=-1, keepdim=True)
    scale = residual_norms * math.sqrt(math.pi / 2.0) / self.qjl_dim
    qjl_term = scale * qjl_correction

    return mse_term + qjl_term
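The two terms of the estimator can be exercised standalone. The sketch below is illustrative only: it replaces the Lloyd-Max stage with a deliberately crude 1-bit sign code for the key (so `k_mse` merely plays the role of `dequantize`'s output), then applies the same sign-based residual correction as `estimate_inner_product`.

```python
import numpy as np

rng = np.random.default_rng(1)
d, m = 64, 65536

k = rng.standard_normal(d)
q = rng.standard_normal(d)

# Coarse stage-1 stand-in for dequantize(): 1-bit sign code scaled by E|k_i|
k_mse = np.sign(k) * np.abs(k).mean()
r = k - k_mse                          # residual the sign stage must account for

S = rng.standard_normal((m, d))        # raw Gaussian projection rows
signs = np.sign(S @ r)                 # the only stored stage-2 information

plain = q @ k_mse                      # stage-1-only score, biased by -<q, r>
corrected = plain + np.linalg.norm(r) * np.sqrt(np.pi / 2) / m * (S @ q) @ signs
exact = q @ k
```

The stage-1 score alone is off by the full <q, r> term; the sign correction recovers it in expectation, which is exactly why `estimate_inner_product` is preferred over `dequantize` for attention scores.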

Functions

solve_lloyd_max

solve_lloyd_max(
    d: int,
    bits: int,
    *,
    use_exact: bool = False,
    max_iter: int = 200,
    tol: float = 1e-10,
) -> tuple[Tensor, Tensor]

Solve the Lloyd-Max conditions for optimal scalar quantization.

Results are cached by (d, bits, use_exact) so that multi-layer models (e.g., 32 layers × 2 K/V compressors = 64 calls) pay the scipy integration cost only once. Without caching, initialization takes 2+ minutes for models like Molmo2-8B.

Parameters:

Name Type Description Default
d int

Vector dimension (determines the distribution shape).

required
bits int

Number of quantization bits (produces 2^bits centroids).

required
use_exact bool

If True, use exact Beta PDF. If False, use Gaussian approximation (faster, accurate for d >= 64).

False
max_iter int

Maximum Lloyd-Max iterations.

200
tol float

Convergence tolerance on centroid movement.

1e-10

Returns:

Type Description
tuple[Tensor, Tensor]

Tuple of (centroids, boundaries) as 1-D tensors. Centroids has length 2^bits; boundaries has length 2^bits - 1.

Source code in src/turboquant_vllm/lloyd_max.py
def solve_lloyd_max(
    d: int,
    bits: int,
    *,
    use_exact: bool = False,
    max_iter: int = 200,
    tol: float = 1e-10,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Solve the Lloyd-Max conditions for optimal scalar quantization.

    Results are cached by (d, bits, use_exact) so that multi-layer models
    (e.g., 32 layers × 2 K/V compressors = 64 calls) pay the scipy
    integration cost only once. Without caching, initialization takes
    2+ minutes for models like Molmo2-8B.

    Args:
        d: Vector dimension (determines the distribution shape).
        bits: Number of quantization bits (produces 2^bits centroids).
        use_exact: If True, use exact Beta PDF. If False, use Gaussian
            approximation (faster, accurate for d >= 64).
        max_iter: Maximum Lloyd-Max iterations.
        tol: Convergence tolerance on centroid movement.

    Returns:
        Tuple of (centroids, boundaries) as 1-D tensors. Centroids has
        length 2^bits, boundaries has length 2^bits - 1.
    """
    return _solve_lloyd_max_cached(d, bits, use_exact, max_iter, tol)