

turboquant_vllm.vllm

TQ4 compressed KV cache backend for vLLM.

Registers a custom attention backend that stores KV cache pages in TurboQuant 4-bit format (68 bytes/token/head vs 256 bytes FP16 = 3.76x compression).
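The per-token figures above follow directly from the packed layout: for an assumed head_size of 128 with one 4-byte fp32 norm per head per token (the layout the packed cache uses), each K or V vector packs into 64 nibble-pair bytes plus the norm:

```python
# Back-of-envelope for the figures above, assuming head_size = 128
# and one 4-byte fp32 norm per head per token.
HEAD_SIZE = 128
NORM_BYTES = 4  # one fp32 norm

tq4_bytes = HEAD_SIZE // 2 + NORM_BYTES   # 64 nibble-packed indices + norm
fp16_bytes = HEAD_SIZE * 2                # 2 bytes per fp16 element

print(tq4_bytes)                          # 68 bytes/token/head (K or V)
print(fp16_bytes)                         # 256 bytes/token/head
print(round(fp16_bytes / tq4_bytes, 2))   # 3.76x compression
```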

Attributes:

TQ4AttentionBackend
    Custom attention backend registered as CUSTOM.

TQ4AttentionImpl
    Attention implementation (compress -> store -> decompress, Phase 3c).

register_tq4_backend
    Callable (returns None) to register the backend manually.

See Also

turboquant_vllm.kv_cache: CompressedDynamicCache for HF transformers.

Usage

The backend registers automatically via the vllm.general_plugins entry point when turboquant-vllm is installed with the vllm extra::

pip install turboquant-vllm[vllm]
vllm serve <model> --attention-backend CUSTOM

Or register manually before starting vLLM::

from turboquant_vllm.vllm import register_tq4_backend

register_tq4_backend()

Classes

TQ4AttentionBackend

Bases: FlashAttentionBackend

TQ4 compressed KV cache attention backend.

Phase 3c: packed uint8 cache layout with real VRAM savings. The cache stores nibble-packed TQ4 indices + fp32 norms as raw bytes. get_kv_cache_shape() returns a 3D (NB, BS, bytes_per_token) layout matching the packed format.

Functions
supports_mm_prefix classmethod
supports_mm_prefix() -> bool

Required for VLMs like Molmo2 with bidirectional visual tokens.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@classmethod
def supports_mm_prefix(cls) -> bool:
    """Required for VLMs like Molmo2 with bidirectional visual tokens."""
    return True
get_name staticmethod
get_name() -> str

Must return "CUSTOM" to match AttentionBackendEnum.CUSTOM.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_name() -> str:
    """Must return ``"CUSTOM"`` to match ``AttentionBackendEnum.CUSTOM``."""
    return "CUSTOM"
get_impl_cls staticmethod
get_impl_cls() -> type[AttentionImplBase]

Return TQ4AttentionImpl.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_impl_cls() -> type[AttentionImplBase]:
    """Return :class:`TQ4AttentionImpl`."""
    return TQ4AttentionImpl
get_builder_cls staticmethod
get_builder_cls() -> type[AttentionMetadataBuilder]

Return FlashAttentionMetadataBuilder (reused unchanged).

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_builder_cls() -> type[AttentionMetadataBuilder]:
    """Return :class:`FlashAttentionMetadataBuilder` -- reused."""
    return FlashAttentionMetadataBuilder
get_kv_cache_shape staticmethod
get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]

Packed TQ4 cache: (num_blocks, block_size, total_bytes).

The last dimension packs K and V data for all heads as raw bytes: [K_indices | K_norms | V_indices | V_norms].

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    """Packed TQ4 cache: ``(num_blocks, block_size, total_bytes)``.

    The last dimension packs K and V data for all heads as raw bytes:
    ``[K_indices | K_norms | V_indices | V_norms]``.
    """
    total_bytes = num_kv_heads * _tq4_bytes_per_token_kv(head_size)
    return (num_blocks, block_size, total_bytes)
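A worked example of the shape this returns, using a standalone reimplementation of _tq4_bytes_per_token_kv derived from the byte layout above ([K_indices | K_norms | V_indices | V_norms], 4-byte fp32 norms); the concrete sizes (8 KV heads, head_size 128, 16-token blocks) are illustrative, not fixed by the backend:

```python
# Sketch of the packed-cache shape arithmetic; tq4_bytes_per_token_kv
# is reconstructed here from the documented byte layout.
NORM_BYTES = 4  # one fp32 norm per head per token

def tq4_bytes_per_token_kv(head_size: int) -> int:
    # K and V each contribute head_size // 2 nibble-packed bytes + one norm
    return 2 * (head_size // 2 + NORM_BYTES)

def kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size):
    total_bytes = num_kv_heads * tq4_bytes_per_token_kv(head_size)
    return (num_blocks, block_size, total_bytes)

# e.g. a Llama-style config: 8 KV heads, head_size 128, 16-token blocks
print(kv_cache_shape(1000, 16, 8, 128))  # (1000, 16, 1088)
```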
get_kv_cache_stride_order staticmethod
get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]

Raise to trigger identity fallback in reshape.

The inherited FlashAttentionBackend returns a 5-element stride order for the standard (2, NB, BS, H, D) shape. Our 3D packed layout (NB, BS, total_bytes) needs identity ordering. Raising NotImplementedError triggers the fallback in _reshape_kv_cache_tensors (same pattern as FlashMLA which does not implement this method at all).

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
    """Raise to trigger identity fallback in reshape.

    The inherited FlashAttentionBackend returns a 5-element stride
    order for the standard ``(2, NB, BS, H, D)`` shape. Our 3D
    packed layout ``(NB, BS, total_bytes)`` needs identity ordering.
    Raising ``NotImplementedError`` triggers the fallback in
    ``_reshape_kv_cache_tensors`` (same pattern as FlashMLA which
    does not implement this method at all).
    """
    raise NotImplementedError

TQ4AttentionImpl

TQ4AttentionImpl(*args, **kwargs)

Bases: FlashAttentionImpl

TQ4 attention: compress -> store -> decompress -> Flash Attention.

Phase 3c: stores packed TQ4 bytes in a uint8 cache for real VRAM savings. Each forward() call:

  1. Compresses incoming K/V tokens to TQ4 packed bytes.
  2. Scatter-writes packed bytes to the uint8 cache via slot_mapping.
  3. Decompresses the full cache to FP16 for Flash Attention.
  4. Calls flash_attn_varlen_func directly with the FP16 data.

Initialize TQ4 attention with compression primitives.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
def __init__(self, *args, **kwargs) -> None:
    """Initialize TQ4 attention with compression primitives."""
    super().__init__(*args, **kwargs)

    # Use attributes set by super().__init__()
    head_size = self.head_size
    num_kv_heads = self.num_kv_heads

    # TQ4 compression primitives (deterministic from seed, shared across layers)
    quantizer = TurboQuantMSE(head_size, TQ4_BITS, seed=TQ4_SEED)
    self._tq4_rotation = quantizer.rotation  # (D, D) fp32
    self._tq4_centroids = quantizer.codebook.centroids  # (16,) fp32
    self._tq4_boundaries = quantizer.codebook.boundaries  # (15,) fp32
    # Pre-split rotation.T for fused compress kernel (contiguous loads)
    rot_t = quantizer.rotation.T.contiguous()
    self._tq4_rot_T_even = rot_t[:, 0::2].contiguous()  # (D, D//2) fp32
    self._tq4_rot_T_odd = rot_t[:, 1::2].contiguous()  # (D, D//2) fp32
    self._tq4_on_device = False

    # Byte layout offsets within the last dimension of the packed cache.
    # Layout: [K_indices(H*D/2) | K_norms(H*4) | V_indices(H*D/2) | V_norms(H*4)]
    half_D = head_size // 2
    self._half_D = half_D
    self._k_idx_end = num_kv_heads * half_D
    self._k_norm_end = self._k_idx_end + num_kv_heads * TQ4_NORM_BYTES
    self._v_idx_end = self._k_norm_end + num_kv_heads * half_D
    self._total_bytes = self._v_idx_end + num_kv_heads * TQ4_NORM_BYTES

    logger.info(
        "TQ4AttentionImpl: %d KV heads, head_size=%d, "
        "%d bytes/token (%.2fx compression vs FP16)",
        num_kv_heads,
        head_size,
        self._total_bytes,
        (2 * num_kv_heads * head_size * 2) / self._total_bytes,
    )
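The offset arithmetic above can be checked by hand for an assumed configuration (8 KV heads, head_size 128, TQ4_NORM_BYTES = 4; these values are illustrative):

```python
# Worked example of the byte-layout offsets computed in __init__,
# assuming num_kv_heads = 8, head_size = 128, TQ4_NORM_BYTES = 4.
H, D, TQ4_NORM_BYTES = 8, 128, 4

half_D = D // 2
k_idx_end = H * half_D                        # 512: end of K nibble indices
k_norm_end = k_idx_end + H * TQ4_NORM_BYTES   # 544: end of K fp32 norms
v_idx_end = k_norm_end + H * half_D           # 1056: end of V nibble indices
total_bytes = v_idx_end + H * TQ4_NORM_BYTES  # 1088 bytes/token

fp16_bytes = 2 * H * D * 2  # K and V in fp16
print(total_bytes, round(fp16_bytes / total_bytes, 2))  # 1088 3.76
```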
Functions
forward
forward(
    layer,
    query,
    key,
    value,
    kv_cache,
    attn_metadata,
    output=None,
    output_scale=None,
    output_block_scale=None,
)

TQ4 attention: compress -> store -> pre-rotate Q -> decompress -> FA -> post-rotate.

Phase 3c.8: Uses fused Triton decompress (no rotation). The rotation is applied to Q before attention and to the output after, saving O(cache_len) matmuls per decode step.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
def forward(
    self,
    layer,
    query,
    key,
    value,
    kv_cache,
    attn_metadata,
    output=None,
    output_scale=None,
    output_block_scale=None,
):
    """TQ4 attention: compress -> store -> pre-rotate Q -> decompress -> FA -> post-rotate.

    Phase 3c.8: Uses fused Triton decompress (no rotation). The
    rotation is applied to Q before attention and to the output
    after, saving O(cache_len) matmuls per decode step.
    """
    assert output is not None

    if output_scale is not None or output_block_scale is not None:
        raise NotImplementedError(
            "Fused output quantization is not supported with TQ4 backend"
        )

    # Profiling mode
    if attn_metadata is None:
        output.zero_()
        return output

    # Encoder attention: no TQ4, delegate to parent
    # (VIT uses a separate backend, but guard just in case)
    from vllm.v1.attention.backend import AttentionType

    if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
        return self._forward_encoder_attention(
            query[: attn_metadata.num_actual_tokens],
            key[: attn_metadata.num_actual_tokens],
            value[: attn_metadata.num_actual_tokens],
            output[: attn_metadata.num_actual_tokens],
            attn_metadata,
            layer,
        )

    num_actual_tokens = attn_metadata.num_actual_tokens

    # Step 1: Compress and store new K/V tokens
    if kv_cache is not None and key is not None and value is not None:
        self._ensure_device(query.device)
        self._compress_and_store(key, value, kv_cache, attn_metadata.slot_mapping)

    # Step 2: Pre-rotate Q by Pi^T (O(num_actual_tokens), not O(cache_len))
    self._ensure_device(query.device)
    q_slice = query[:num_actual_tokens]
    q_rot = (q_slice.float() @ self._tq4_rotation.T).to(q_slice.dtype)

    # Step 3: Decompress full cache (Triton fused, skip rotation)
    key_cache, value_cache = self._decompress_cache(
        kv_cache,
        query.dtype,
        apply_rotation=False,
    )

    # Step 4: Run Flash Attention with rotated Q and rotated KV
    from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func

    if attn_metadata.use_cascade:
        raise NotImplementedError("TQ4 does not yet support cascade attention")

    descale_shape = (
        attn_metadata.query_start_loc.shape[0] - 1,
        self.num_kv_heads,
    )
    q_descale = layer._q_scale.expand(descale_shape)
    k_descale = layer._k_scale.expand(descale_shape)
    v_descale = layer._v_scale.expand(descale_shape)

    flash_attn_varlen_func(
        q=q_rot,
        k=key_cache,
        v=value_cache,
        out=output[:num_actual_tokens],
        cu_seqlens_q=attn_metadata.query_start_loc,
        max_seqlen_q=attn_metadata.max_query_len,
        seqused_k=attn_metadata.seq_lens,
        max_seqlen_k=attn_metadata.max_seq_len,
        softmax_scale=self.scale,
        causal=attn_metadata.causal,
        alibi_slopes=self.alibi_slopes,  # ty: ignore[invalid-argument-type]
        window_size=list(self.sliding_window)
        if self.sliding_window is not None
        else None,
        block_table=attn_metadata.block_table,
        softcap=self.logits_soft_cap,
        scheduler_metadata=attn_metadata.scheduler_metadata,
        fa_version=self.vllm_flash_attn_version,  # ty: ignore[invalid-argument-type]
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
        num_splits=attn_metadata.max_num_splits,
        s_aux=self.sinks,
    )

    # Step 5: Post-rotate output by Pi (undo rotation space)
    out_slice = output[:num_actual_tokens]
    output[:num_actual_tokens] = (out_slice.float() @ self._tq4_rotation).to(
        out_slice.dtype
    )

    return output
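The pre-rotate/post-rotate trick in Steps 2 and 5 relies on the rotation being orthogonal: scores satisfy (Q Pᵀ)(K Pᵀ)ᵀ = Q Kᵀ, and attention over rotated V followed by multiplication with P recovers the plain output. A minimal NumPy check under hypothetical small sizes (not the backend's real kernels):

```python
# NumPy sanity check of the rotation identity: attention over K/V stored
# in rotated space, with Q pre-rotated and the output post-rotated,
# equals plain attention when P is orthogonal.
import numpy as np

rng = np.random.default_rng(0)
D, T = 8, 5
Q, K, V = (rng.standard_normal((T, D)) for _ in range(3))
P, _ = np.linalg.qr(rng.standard_normal((D, D)))  # random orthogonal matrix

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attn(q, k, v):
    return softmax(q @ k.T / np.sqrt(D)) @ v

plain = attn(Q, K, V)
# cache holds K/V rotated by P^T; rotate Q the same way, un-rotate output
rotated = attn(Q @ P.T, K @ P.T, V @ P.T) @ P

assert np.allclose(plain, rotated)
```

This is why only O(num_actual_tokens) matmuls are needed per step: the rotation never has to be applied to the O(cache_len) decompressed K/V.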

TQ4FullAttentionSpec dataclass

TQ4FullAttentionSpec()

Bases: FullAttentionSpec

KV cache spec with TQ4 packed page size.

Overrides real_page_size_bytes so the block allocator provisions buffers sized for the packed TQ4 format (3.76x smaller than FP16). Follows the same pattern as MLAAttentionSpec which overrides page size for the 656-byte FlashMLA format.

Functions

register_tq4_backend

register_tq4_backend() -> None

Register TQ4 as the CUSTOM attention backend.

In addition to registering the backend class, this monkey-patches Attention.get_kv_cache_spec so that decoder attention layers return TQ4FullAttentionSpec (with dtype=torch.uint8 and TQ4-sized pages) instead of the standard FullAttentionSpec.

Called automatically by the vllm.general_plugins entry point, or manually before starting vLLM::

from turboquant_vllm.vllm import register_tq4_backend

register_tq4_backend()
# then start vLLM with --attention-backend CUSTOM
Source code in src/turboquant_vllm/vllm/tq4_backend.py
def register_tq4_backend() -> None:
    """Register TQ4 as the CUSTOM attention backend.

    In addition to registering the backend class, this monkey-patches
    ``Attention.get_kv_cache_spec`` so that decoder attention layers
    return :class:`TQ4FullAttentionSpec` (with ``dtype=torch.uint8``
    and TQ4-sized pages) instead of the standard ``FullAttentionSpec``.

    Called automatically by the ``vllm.general_plugins`` entry point,
    or manually before starting vLLM::

        from turboquant_vllm.vllm import register_tq4_backend

        register_tq4_backend()
        # then start vLLM with --attention-backend CUSTOM
    """
    global _original_get_kv_cache_spec  # noqa: PLW0603

    register_backend(
        AttentionBackendEnum.CUSTOM,
        "turboquant_vllm.vllm.tq4_backend.TQ4AttentionBackend",
    )

    # Register TQ4FullAttentionSpec in the KV cache manager mapping.
    # vLLM uses exact type() match, not isinstance(), so subclasses
    # of FullAttentionSpec must be explicitly added.
    from vllm.v1.core.single_type_kv_cache_manager import spec_manager_map

    if TQ4FullAttentionSpec not in spec_manager_map:
        spec_manager_map[TQ4FullAttentionSpec] = spec_manager_map[FullAttentionSpec]

    # Monkey-patch Attention.get_kv_cache_spec to return TQ4 spec
    from vllm.model_executor.layers.attention.attention import Attention

    if _original_get_kv_cache_spec is None:
        _original_get_kv_cache_spec = Attention.get_kv_cache_spec

    def _tq4_get_kv_cache_spec(self, vllm_config):
        spec = _original_get_kv_cache_spec(self, vllm_config)
        if isinstance(spec, FullAttentionSpec) and not isinstance(
            spec, TQ4FullAttentionSpec
        ):
            kwargs = {f.name: getattr(spec, f.name) for f in dc_fields(spec)}
            kwargs["dtype"] = torch.uint8
            return TQ4FullAttentionSpec(**kwargs)
        return spec

    Attention.get_kv_cache_spec = _tq4_get_kv_cache_spec  # ty: ignore[invalid-assignment]
    logger.info("TQ4 attention backend registered as CUSTOM (packed cache)")