

turboquant_vllm.vllm

TQ4 compressed KV cache backend for vLLM.

Registers a custom attention backend that stores KV cache pages in TurboQuant 4-bit format (68 bytes per token per head versus 256 bytes in FP16, roughly 3.76x compression).
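As a quick sanity check on the quoted ratio (assuming head_size = 64, the only head size consistent with the 256-byte FP16 figure; the 68-byte packed slot is taken as given rather than re-derived here):

# Back-of-the-envelope check of the quoted compression ratio.
# head_size = 64 is an assumption implied by the 256-byte FP16 figure.
head_size = 64
fp16_bytes = 2 * head_size * 2   # K + V, 2 bytes per FP16 element -> 256
tq4_bytes = 68                   # packed slot size quoted above
print(f"{fp16_bytes / tq4_bytes:.2f}x")  # ~3.76x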

Attributes:

Name                   Type  Description
TQ4AttentionBackend          Custom attention backend registered as CUSTOM.
TQ4AttentionImpl             Attention implementation (passthrough in Phase 3a).
register_tq4_backend   None  Callable to register the backend manually.

See Also

turboquant_vllm.kv_cache: CompressedDynamicCache for HF transformers.

Usage

The backend registers automatically via the vllm.general_plugins entry point when turboquant-vllm is installed with the vllm extra:

pip install turboquant-vllm[vllm]
vllm serve <model> --attention-backend CUSTOM

Or register manually before starting vLLM:

from turboquant_vllm.vllm import register_tq4_backend

register_tq4_backend()

Classes

TQ4AttentionBackend

Bases: FlashAttentionBackend

TQ4 compressed KV cache attention backend.

Phase 3c: packed uint8 cache layout with real VRAM savings. The cache stores nibble-packed TQ4 indices plus fp32 norms as raw bytes. get_kv_cache_shape() returns a 3D (num_blocks, block_size, bytes_per_token) layout matching the packed format.

Functions
supports_mm_prefix classmethod
supports_mm_prefix() -> bool

Required for VLMs like Molmo2 with bidirectional visual tokens.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@classmethod
def supports_mm_prefix(cls) -> bool:
    """Required for VLMs like Molmo2 with bidirectional visual tokens."""
    return True
get_name staticmethod
get_name() -> str

Must return "CUSTOM" to match AttentionBackendEnum.CUSTOM.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_name() -> str:
    """Must return ``"CUSTOM"`` to match ``AttentionBackendEnum.CUSTOM``."""
    return "CUSTOM"
get_impl_cls staticmethod
get_impl_cls() -> type[AttentionImplBase]

Return TQ4AttentionImpl.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_impl_cls() -> type[AttentionImplBase]:
    """Return :class:`TQ4AttentionImpl`."""
    return TQ4AttentionImpl
get_builder_cls staticmethod
get_builder_cls() -> type[AttentionMetadataBuilder]

Return TQ4MetadataBuilder for CUDA graph support.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_builder_cls() -> type[AttentionMetadataBuilder]:
    """Return :class:`TQ4MetadataBuilder` for CUDA graph support."""
    return TQ4MetadataBuilder
get_kv_cache_shape staticmethod
get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]

Packed TQ4 cache: (num_blocks, block_size, padded_bytes).

The last dimension packs K and V data for all heads as raw bytes with padding for hybrid model page alignment. Only the first num_kv_heads * _tq4_bytes_per_token_kv(head_size) bytes per token contain packed data; trailing bytes are unused padding.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    """Packed TQ4 cache: ``(num_blocks, block_size, padded_bytes)``.

    The last dimension packs K and V data for all heads as raw bytes
    with padding for hybrid model page alignment. Only the first
    ``num_kv_heads * _tq4_bytes_per_token_kv(head_size)`` bytes per
    token contain packed data; trailing bytes are unused padding.
    """
    total_bytes = num_kv_heads * _padded_slot_bytes(head_size)
    return (num_blocks, block_size, total_bytes)
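
For a rough picture of the per-token byte budget behind get_kv_cache_shape(), here is a sketch of what _tq4_bytes_per_token_kv plausibly computes. This is illustrative only: the real helpers (_tq4_bytes_per_token_kv, _padded_slot_bytes) live in tq4_backend.py and include hybrid-model padding, and the norm size used below is an assumption.

TQ4_NORM_BYTES = 4  # assumed: one fp32 norm per K/V component; the module
                    # docstring's 68-byte figure suggests it may be smaller

def tq4_bytes_per_token_kv(head_size: int) -> int:
    k_indices = head_size // 2   # nibble-packed 4-bit K indices
    v_indices = head_size // 2   # nibble-packed 4-bit V indices
    return k_indices + v_indices + 2 * TQ4_NORM_BYTES

def kv_cache_shape(num_blocks: int, block_size: int,
                   num_kv_heads: int, head_size: int) -> tuple[int, ...]:
    # Mirrors get_kv_cache_shape() above, minus the _padded_slot_bytes padding.
    return (num_blocks, block_size,
            num_kv_heads * tq4_bytes_per_token_kv(head_size))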
get_kv_cache_stride_order staticmethod
get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]

Raise to trigger identity fallback in reshape.

The inherited FlashAttentionBackend returns a 5-element stride order for the standard (2, NB, BS, H, D) shape. Our 3D packed layout (NB, BS, total_bytes) needs identity ordering. Raising NotImplementedError triggers the fallback in _reshape_kv_cache_tensors (same pattern as FlashMLA which does not implement this method at all).

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
    """Raise to trigger identity fallback in reshape.

    The inherited FlashAttentionBackend returns a 5-element stride
    order for the standard ``(2, NB, BS, H, D)`` shape. Our 3D
    packed layout ``(NB, BS, total_bytes)`` needs identity ordering.
    Raising ``NotImplementedError`` triggers the fallback in
    ``_reshape_kv_cache_tensors`` (same pattern as FlashMLA which
    does not implement this method at all).
    """
    raise NotImplementedError

TQ4AttentionImpl

TQ4AttentionImpl(*args, **kwargs)

Bases: FlashAttentionImpl

TQ4 attention: compress -> store -> decompress -> Flash Attention.

Phase 3c: stores packed TQ4 bytes in a uint8 cache for real VRAM savings. Each forward() call:

  1. Compresses incoming K/V tokens to TQ4 packed bytes.
  2. Scatter-writes packed bytes to the uint8 cache via slot_mapping (a minimal sketch of this step follows the list).
  3. Decompresses the full cache to FP16 for Flash Attention.
  4. Calls flash_attn_varlen_func directly with the FP16 data.
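
A minimal sketch of step 2's scatter-write, assuming packed is a (num_tokens, bytes_per_token) uint8 tensor and slot_mapping holds flat slot indices in vLLM's usual convention; the real path uses a Triton kernel rather than index_copy_:

import torch

def scatter_packed(
    kv_cache: torch.Tensor,      # (num_blocks, block_size, bytes_per_token) uint8
    packed: torch.Tensor,        # (num_tokens, bytes_per_token) uint8
    slot_mapping: torch.Tensor,  # (num_tokens,) flat slot index per token
) -> None:
    num_blocks, block_size, bytes_per_token = kv_cache.shape
    flat = kv_cache.view(num_blocks * block_size, bytes_per_token)
    flat.index_copy_(0, slot_mapping.long(), packed)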

Initialize TQ4 attention with compression primitives.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
def __init__(self, *args, **kwargs) -> None:
    """Initialize TQ4 attention with compression primitives."""
    super().__init__(*args, **kwargs)

    # Use attributes set by super().__init__()
    head_size = self.head_size
    num_kv_heads = self.num_kv_heads

    # Resolve per-component bit-widths from env vars
    k_bits, v_bits = _parse_kv_bits_env()
    self._k_bits = k_bits
    self._v_bits = v_bits

    # TQ4 compression primitives (deterministic from seed, shared across layers)
    # Rotation matrix is dim-dependent only (not bits-dependent), so shared.
    k_quantizer = TurboQuantMSE(head_size, k_bits, seed=TQ4_SEED)
    v_quantizer = (
        k_quantizer
        if v_bits == k_bits
        else TurboQuantMSE(head_size, v_bits, seed=TQ4_SEED)
    )

    # Eagerly move primitives to the target device (D7 mod 5).
    # FlashAttentionImpl.__init__ doesn't expose device, but
    # vLLM's global config is available during model construction.
    from vllm.config import get_current_vllm_config_or_none

    vllm_config = get_current_vllm_config_or_none()
    device = (
        vllm_config.device_config.device
        if vllm_config is not None
        else torch.device("cpu")
    )

    # Shared rotation (dim-dependent only, same seed → identical matrix)
    self._tq4_rotation = k_quantizer.rotation.to(device)  # (D, D) fp32
    # Pre-split rotation.T for fused compress kernel (contiguous loads)
    rot_t = k_quantizer.rotation.T.contiguous()
    self._tq4_rot_T_even = rot_t[:, 0::2].contiguous().to(device)  # (D, D//2) fp32
    self._tq4_rot_T_odd = rot_t[:, 1::2].contiguous().to(device)  # (D, D//2) fp32

    # Per-component codebooks (may differ when k_bits != v_bits)
    self._k_centroids = k_quantizer.codebook.centroids.to(device)
    self._k_boundaries = k_quantizer.codebook.boundaries.to(device)
    self._v_centroids = v_quantizer.codebook.centroids.to(device)
    self._v_boundaries = v_quantizer.codebook.boundaries.to(device)

    # Byte layout offsets within the last dimension of the packed cache.
    # Layout: [K_indices | K_norms | V_indices | V_norms]
    # The Triton compress kernel always nibble-packs (head_dim // 2 bytes
    # per component), regardless of bit-width. Different codebook sizes
    # provide quality improvement, not storage savings in the vLLM path.
    k_idx_size = head_size // 2
    v_idx_size = head_size // 2
    self._k_idx_size = k_idx_size
    self._v_idx_size = v_idx_size
    self._k_idx_end = num_kv_heads * k_idx_size
    self._k_norm_end = self._k_idx_end + num_kv_heads * TQ4_NORM_BYTES
    self._v_idx_end = self._k_norm_end + num_kv_heads * v_idx_size
    self._total_bytes = self._v_idx_end + num_kv_heads * TQ4_NORM_BYTES

    # CUDA graph scratch buffers (D7 mod 2) — lazy-allocated on first
    # forward() from kv_cache.shape, which is stable for engine lifetime.
    # First forward runs during vLLM warmup, before graph capture.
    self._cg_buffers_ready = False

    # Fused paged decode feature gate (Story 6.3, AC 1+6).
    # Explicit opt-in via TQ4_USE_FUSED_PAGED env var AND successful
    # kernel import.  Disabled for asymmetric configs — fused kernel
    # uses a single codebook and cannot handle k_bits != v_bits.
    self._fused_paged_available = (
        _parse_fused_paged_env()
        and _fused_paged_kernel_available
        and k_bits == v_bits
    )

    # INT8 prefill gate (Story 6.4): requires fused decode gate + its own
    # env var + successful kernel import.
    self._int8_prefill_available = (
        self._fused_paged_available
        and _parse_int8_prefill_env()
        and _int8_prefill_kernel_available
    )

    # Buffer downsizing source: scheduler knows its own max prefill length.
    # Fallback 2048 matches vLLM's default max_num_batched_tokens for
    # chunked prefill.
    self._max_prefill_len = (
        vllm_config.scheduler_config.max_num_batched_tokens
        if vllm_config is not None
        else 2048
    )

    # Decode buffer bound: max_model_len caps decompress buffer instead
    # of full cache capacity.  Fallback 6144 matches Molmo2 default.
    self._max_model_len = (
        vllm_config.model_config.max_model_len if vllm_config is not None else 6144
    )

    logger.info(
        "TQ4AttentionImpl: %d KV heads, head_size=%d, k_bits=%d, v_bits=%d, "
        "%d bytes/token (%.2fx compression vs FP16)",
        num_kv_heads,
        head_size,
        k_bits,
        v_bits,
        self._total_bytes,
        (2 * num_kv_heads * head_size * 2) / self._total_bytes,
    )
    logger.info(
        "Fused paged TQ4 decode: %s",
        "enabled" if self._fused_paged_available else "disabled",
    )
    logger.info(
        "INT8 prefill path: %s",
        "enabled" if self._int8_prefill_available else "disabled",
    )
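
The gates set up above are driven by environment variables. A hedged example of setting the variables named in this module before vLLM constructs the model (the accepted values are an assumption; they are whatever _parse_kv_bits_env and _parse_fused_paged_env actually accept):

import os

# Variable names taken from the docstrings/comments above; "4" and "1" are
# illustrative values, not confirmed formats.
os.environ.setdefault("TQ4_K_BITS", "4")
os.environ.setdefault("TQ4_V_BITS", "4")
os.environ.setdefault("TQ4_USE_FUSED_PAGED", "1")  # opt-in fused paged decode
# The INT8 prefill path has its own env var (read by _parse_int8_prefill_env);
# its name is not shown here, so it is left unset.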
Functions
forward
forward(
    layer,
    query,
    key,
    value,
    kv_cache,
    attn_metadata,
    output=None,
    output_scale=None,
    output_block_scale=None,
)

TQ4 attention: compress -> store -> pre-rotate Q -> decompress -> FA -> post-rotate.

Phase 3c.8: Uses fused Triton decompress (no rotation). The rotation is applied to Q before attention and to the output after, saving O(cache_len) matmuls per decode step.
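
Why pre-rotating Q is equivalent: for an orthogonal rotation R applied to K and V at compression time, the attention scores are unchanged because (QR)(KR)ᵀ = QKᵀ, and the rotation on V passes linearly through the softmax-weighted sum, so a single matmul on the output undoes it. A toy numerical check (the transpose convention of self._tq4_rotation is an assumption here):

import torch

torch.manual_seed(0)
D = 8
# Random orthogonal rotation (Q factor of a QR decomposition)
R, _ = torch.linalg.qr(torch.randn(D, D))

q, k, v = torch.randn(3, 5, D).unbind(0)

ref = torch.softmax(q @ k.T, dim=-1) @ v

# Keys/values stored in rotated space; query rotated the same way.
q_r, k_r, v_r = q @ R, k @ R, v @ R
out_r = torch.softmax(q_r @ k_r.T, dim=-1) @ v_r

# One post-rotation matmul recovers the reference output.
assert torch.allclose(out_r @ R.T, ref, atol=1e-5)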

Source code in src/turboquant_vllm/vllm/tq4_backend.py
def forward(
    self,
    layer,
    query,
    key,
    value,
    kv_cache,
    attn_metadata,
    output=None,
    output_scale=None,
    output_block_scale=None,
):
    """TQ4 attention: compress -> store -> pre-rotate Q -> decompress -> FA -> post-rotate.

    Phase 3c.8: Uses fused Triton decompress (no rotation). The
    rotation is applied to Q before attention and to the output
    after, saving O(cache_len) matmuls per decode step.
    """
    assert output is not None

    if output_scale is not None or output_block_scale is not None:
        raise NotImplementedError(
            "Fused output quantization is not supported with TQ4 backend"
        )

    # Profiling mode
    if attn_metadata is None:
        output.zero_()
        return output

    # Warmup with no cache allocated yet
    if kv_cache is None:
        output.zero_()
        return output

    # Encoder attention: no TQ4, delegate to parent
    # (VIT uses a separate backend, but guard just in case)
    from vllm.v1.attention.backend import AttentionType

    if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
        return self._forward_encoder_attention(
            query[: attn_metadata.num_actual_tokens],
            key[: attn_metadata.num_actual_tokens],
            value[: attn_metadata.num_actual_tokens],
            output[: attn_metadata.num_actual_tokens],
            attn_metadata,
            layer,
        )

    num_actual_tokens = attn_metadata.num_actual_tokens

    # Lazy-init CUDA graph buffers on first forward (during warmup)
    if not self._cg_buffers_ready and kv_cache is not None:
        self._init_cg_buffers(kv_cache, compute_dtype=query.dtype)

    # Steps 1-3: compress, rotate Q, decompress (decode vs prefill path)
    is_decode = self._cg_buffers_ready and num_actual_tokens == 1

    # Fused paged decode (Story 6.3): single kernel replaces
    # decompress + FlashAttn + post-rotate for decode steps.
    if self._fused_paged_available and is_decode:
        return self._fused_decode_path(
            query, key, value, kv_cache, attn_metadata, output
        )

    # INT8 prefill (Story 6.4): IMMA tensor core Q@K^T for prefill.
    # Guard: kernel is single-sequence only; fall back for multi-sequence
    # batches (vLLM scheduler may combine multiple requests).
    if (
        self._int8_prefill_available
        and not is_decode
        and attn_metadata.seq_lens.shape[0] == 1
    ):
        return self._int8_prefill_path(
            query, key, value, kv_cache, attn_metadata, output
        )

    if is_decode:
        q_rot, key_cache, value_cache, fa_block_table = self._tq4_decode(
            query, key, value, kv_cache, attn_metadata
        )
    else:
        q_rot, key_cache, value_cache, fa_block_table = self._tq4_prefill(
            query, key, value, kv_cache, attn_metadata
        )

    # Step 4: Run Flash Attention with rotated Q and rotated KV
    from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func

    if attn_metadata.use_cascade:
        raise NotImplementedError("TQ4 does not yet support cascade attention")

    descale_shape = (
        attn_metadata.query_start_loc.shape[0] - 1,
        self.num_kv_heads,
    )
    q_descale = layer._q_scale.expand(descale_shape)
    k_descale = layer._k_scale.expand(descale_shape)
    v_descale = layer._v_scale.expand(descale_shape)

    flash_attn_varlen_func(
        q=q_rot,
        k=key_cache,
        v=value_cache,
        out=output[:num_actual_tokens],
        cu_seqlens_q=attn_metadata.query_start_loc,
        max_seqlen_q=attn_metadata.max_query_len,
        seqused_k=attn_metadata.seq_lens,
        max_seqlen_k=attn_metadata.max_seq_len,
        softmax_scale=self.scale,
        causal=attn_metadata.causal,
        alibi_slopes=self.alibi_slopes,
        window_size=list(self.sliding_window)
        if self.sliding_window is not None
        else None,
        block_table=fa_block_table,
        softcap=self.logits_soft_cap,
        scheduler_metadata=attn_metadata.scheduler_metadata,
        fa_version=self.vllm_flash_attn_version,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
        num_splits=attn_metadata.max_num_splits,
        s_aux=self.sinks,
    )

    # Step 5: Post-rotate output by Pi (undo rotation space)
    out_slice = output[:num_actual_tokens]
    output[:num_actual_tokens] = (out_slice.float() @ self._tq4_rotation).to(
        out_slice.dtype
    )

    return output

TQ4FullAttentionSpec dataclass

TQ4FullAttentionSpec()

Bases: FullAttentionSpec

KV cache spec with TQ4 packed page size.

Overrides real_page_size_bytes so the block allocator provisions buffers sized for the packed TQ4 format. Supports asymmetric K/V bit-widths via TQ4_K_BITS / TQ4_V_BITS env vars. Follows the same pattern as MLAAttentionSpec which overrides page size for the 656-byte FlashMLA format.
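
A rough sketch of what such a page-size override can look like; the AttentionSpec field names (block_size, num_kv_heads, head_size), the vLLM import path, and the bytes_per_token_kv helper are assumptions, and the real TQ4FullAttentionSpec also folds in the K/V bit-width env vars and slot padding:

from dataclasses import dataclass

from vllm.v1.kv_cache_interface import FullAttentionSpec  # path is an assumption

def bytes_per_token_kv(head_size: int, norm_bytes: int = 4) -> int:
    # Hypothetical helper standing in for _tq4_bytes_per_token_kv.
    return head_size + 2 * norm_bytes  # packed K idx + V idx + two norms

@dataclass
class PackedPageSizeSpecSketch(FullAttentionSpec):
    """Illustrative only; not the real TQ4FullAttentionSpec."""

    @property
    def real_page_size_bytes(self) -> int:
        # One page holds block_size tokens; each token stores packed bytes
        # for every KV head.
        return (self.block_size * self.num_kv_heads
                * bytes_per_token_kv(self.head_size))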

Functions

register_tq4_backend

register_tq4_backend() -> None

Register TQ4 as the CUSTOM attention backend.

In addition to registering the backend class, this monkey-patches Attention.get_kv_cache_spec so that decoder attention layers return TQ4FullAttentionSpec (with dtype=torch.uint8 and TQ4-sized pages) instead of the standard FullAttentionSpec.

Called automatically by the vllm.general_plugins entry point, or manually before starting vLLM:

from turboquant_vllm.vllm import register_tq4_backend

register_tq4_backend()
# then start vLLM with --attention-backend CUSTOM
Source code in src/turboquant_vllm/vllm/tq4_backend.py
def register_tq4_backend() -> None:
    """Register TQ4 as the CUSTOM attention backend.

    In addition to registering the backend class, this monkey-patches
    ``Attention.get_kv_cache_spec`` so that decoder attention layers
    return :class:`TQ4FullAttentionSpec` (with ``dtype=torch.uint8``
    and TQ4-sized pages) instead of the standard ``FullAttentionSpec``.

    Called automatically by the ``vllm.general_plugins`` entry point,
    or manually before starting vLLM::

        from turboquant_vllm.vllm import register_tq4_backend

        register_tq4_backend()
        # then start vLLM with --attention-backend CUSTOM
    """
    global _original_get_kv_cache_spec  # noqa: PLW0603

    register_backend(
        AttentionBackendEnum.CUSTOM,
        "turboquant_vllm.vllm.tq4_backend.TQ4AttentionBackend",
    )

    # Register TQ4FullAttentionSpec in the KV cache manager mapping.
    # vLLM uses exact type() match, not isinstance(), so subclasses
    # of FullAttentionSpec must be explicitly added.
    from vllm.v1.core.single_type_kv_cache_manager import spec_manager_map

    if TQ4FullAttentionSpec not in spec_manager_map:
        spec_manager_map[TQ4FullAttentionSpec] = spec_manager_map[FullAttentionSpec]

    # Monkey-patch Attention.get_kv_cache_spec to return TQ4 spec
    from vllm.model_executor.layers.attention.attention import Attention

    if _original_get_kv_cache_spec is None:
        _original_get_kv_cache_spec = Attention.get_kv_cache_spec

    def _tq4_get_kv_cache_spec(self, vllm_config):
        spec = _original_get_kv_cache_spec(self, vllm_config)
        if isinstance(spec, FullAttentionSpec) and not isinstance(
            spec, TQ4FullAttentionSpec
        ):
            kwargs = {f.name: getattr(spec, f.name) for f in dc_fields(spec)}
            kwargs["dtype"] = torch.uint8
            return TQ4FullAttentionSpec(**kwargs)
        return spec

    Attention.get_kv_cache_spec = _tq4_get_kv_cache_spec
    logger.info("TQ4 attention backend registered as CUSTOM (packed cache)")