tq4_backend

turboquant_vllm.vllm.tq4_backend

TQ4 compressed KV cache attention backend for vLLM.

Phase 3c: Packed TQ4 cache layout with real VRAM savings.

The KV cache is stored as uint8 bytes in a packed TQ4 format (68 bytes per token per head per K/V = 136 bytes total vs 512 bytes FP16 = 3.76x compression). Buffer allocation uses a custom TQ4FullAttentionSpec that overrides page_size_bytes so the block allocator provisions 3.76x more blocks in the same VRAM budget. Each forward() call decompresses the relevant blocks to FP16 and delegates to Flash Attention.
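The byte figures above can be checked with simple arithmetic. A sketch, assuming head_size=128, 4-bit nibble-packed indices, and one fp32 norm per head per K/V tensor (these constants are illustrative assumptions, not read from the real code):

```python
# Illustrative check of the packed TQ4 byte counts quoted above.
# Assumptions: head_size=128, 4-bit nibble-packed indices (2 per byte),
# one fp32 norm (4 bytes) per head per K/V tensor.
head_size = 128
tq4_per_head = head_size // 2 + 4   # 64 index bytes + 4 norm bytes = 68
fp16_per_head = head_size * 2       # 256 bytes for K or V in FP16
kv_tq4 = 2 * tq4_per_head           # 136 bytes per token per head (K + V)
kv_fp16 = 2 * fp16_per_head         # 512 bytes per token per head (K + V)
print(tq4_per_head, kv_tq4, round(kv_fp16 / kv_tq4, 2))  # 68 136 3.76
```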

Implementation phases

- 3a (done): Passthrough skeleton -- validated plugin wiring.
- 3b (done): Compress-decompress round-trip in the standard FP16 cache.
- 3c (this phase): Packed uint8 cache with real VRAM savings.
- 3d: Production benchmark against the vLLM baseline.

Classes

TQ4FullAttentionSpec dataclass

TQ4FullAttentionSpec()

Bases: FullAttentionSpec

KV cache spec with TQ4 packed page size.

Overrides real_page_size_bytes so the block allocator provisions buffers sized for the packed TQ4 format (3.76x smaller than FP16). Follows the same pattern as MLAAttentionSpec which overrides page size for the 656-byte FlashMLA format.
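The override pattern can be sketched without the vLLM types. All class and field names below are illustrative stand-ins, not the real FullAttentionSpec API; the point is only that reporting a smaller page size makes a byte-budgeted allocator provision proportionally more blocks:

```python
from dataclasses import dataclass

# Minimal sketch (not the vLLM API) of the page-size override pattern:
# a spec subclass reports a smaller page size, so an allocator dividing
# a fixed byte budget by page_size_bytes provisions ~3.76x more blocks.
@dataclass
class FakeFullSpec:
    block_size: int      # tokens per block
    num_kv_heads: int
    head_size: int

    @property
    def page_size_bytes(self) -> int:
        # FP16 K+V: 2 tensors * 2 bytes/elem * tokens * heads * head dim
        return 2 * 2 * self.block_size * self.num_kv_heads * self.head_size

@dataclass
class FakeTQ4Spec(FakeFullSpec):
    @property
    def page_size_bytes(self) -> int:
        # Packed TQ4: D/2 index bytes + 4 norm bytes per head, for K and V
        per_token = 2 * self.num_kv_heads * (self.head_size // 2 + 4)
        return self.block_size * per_token

full, tq4 = FakeFullSpec(16, 8, 128), FakeTQ4Spec(16, 8, 128)
print(round(full.page_size_bytes / tq4.page_size_bytes, 2))  # 3.76
```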

TQ4AttentionBackend

Bases: FlashAttentionBackend

TQ4 compressed KV cache attention backend.

Phase 3c: packed uint8 cache layout with real VRAM savings. The cache stores nibble-packed TQ4 indices + fp32 norms as raw bytes. get_kv_cache_shape() returns a 3D (NB, BS, bytes_per_token) layout matching the packed format.

Functions
supports_mm_prefix classmethod
supports_mm_prefix() -> bool

Required for VLMs like Molmo2 with bidirectional visual tokens.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@classmethod
def supports_mm_prefix(cls) -> bool:
    """Required for VLMs like Molmo2 with bidirectional visual tokens."""
    return True
get_name staticmethod
get_name() -> str

Must return "CUSTOM" to match AttentionBackendEnum.CUSTOM.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_name() -> str:
    """Must return ``"CUSTOM"`` to match ``AttentionBackendEnum.CUSTOM``."""
    return "CUSTOM"
get_impl_cls staticmethod
get_impl_cls() -> type[AttentionImplBase]

Return TQ4AttentionImpl.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_impl_cls() -> type[AttentionImplBase]:
    """Return :class:`TQ4AttentionImpl`."""
    return TQ4AttentionImpl
get_builder_cls staticmethod
get_builder_cls() -> type[AttentionMetadataBuilder]

Return FlashAttentionMetadataBuilder -- reused from the parent backend.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_builder_cls() -> type[AttentionMetadataBuilder]:
    """Return :class:`FlashAttentionMetadataBuilder` -- reused."""
    return FlashAttentionMetadataBuilder
get_kv_cache_shape staticmethod
get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]

Packed TQ4 cache: (num_blocks, block_size, total_bytes).

The last dimension packs K and V data for all heads as raw bytes: [K_indices | K_norms | V_indices | V_norms].

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    """Packed TQ4 cache: ``(num_blocks, block_size, total_bytes)``.

    The last dimension packs K and V data for all heads as raw bytes:
    ``[K_indices | K_norms | V_indices | V_norms]``.
    """
    total_bytes = num_kv_heads * _tq4_bytes_per_token_kv(head_size)
    return (num_blocks, block_size, total_bytes)
get_kv_cache_stride_order staticmethod
get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]

Raise to trigger identity fallback in reshape.

The inherited FlashAttentionBackend returns a 5-element stride order for the standard (2, NB, BS, H, D) shape. Our 3D packed layout (NB, BS, total_bytes) needs identity ordering. Raising NotImplementedError triggers the fallback in _reshape_kv_cache_tensors (same pattern as FlashMLA which does not implement this method at all).

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
    """Raise to trigger identity fallback in reshape.

    The inherited FlashAttentionBackend returns a 5-element stride
    order for the standard ``(2, NB, BS, H, D)`` shape. Our 3D
    packed layout ``(NB, BS, total_bytes)`` needs identity ordering.
    Raising ``NotImplementedError`` triggers the fallback in
    ``_reshape_kv_cache_tensors`` (same pattern as FlashMLA which
    does not implement this method at all).
    """
    raise NotImplementedError

TQ4AttentionImpl

TQ4AttentionImpl(*args, **kwargs)

Bases: FlashAttentionImpl

TQ4 attention: compress -> store -> decompress -> Flash Attention.

Phase 3c: stores packed TQ4 bytes in a uint8 cache for real VRAM savings. Each forward() call:

  1. Compresses incoming K/V tokens to TQ4 packed bytes.
  2. Scatter-writes packed bytes to the uint8 cache via slot_mapping.
  3. Decompresses the full cache to FP16 for Flash Attention.
  4. Calls flash_attn_varlen_func directly with the FP16 data.
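Step 1's nibble packing (two 4-bit codebook indices per uint8 byte, giving the D/2 index bytes per head in the cache layout) can be illustrated in isolation. The function names and NumPy reference implementation below are assumptions for illustration, not the Triton kernels the backend actually uses:

```python
import numpy as np

# Illustrative round-trip of 4-bit nibble packing: two codebook
# indices per uint8 byte, low nibble first. Names are hypothetical;
# the real backend does this inside fused Triton kernels.
def pack_nibbles(indices: np.ndarray) -> np.ndarray:
    """Pack an even-length array of 4-bit values (0..15) into bytes."""
    lo = indices[0::2] & 0x0F
    hi = indices[1::2] & 0x0F
    return (lo | (hi << 4)).astype(np.uint8)

def unpack_nibbles(packed: np.ndarray) -> np.ndarray:
    """Inverse of pack_nibbles: recover the 4-bit index stream."""
    out = np.empty(packed.size * 2, dtype=np.uint8)
    out[0::2] = packed & 0x0F
    out[1::2] = packed >> 4
    return out

idx = np.arange(8, dtype=np.uint8) % 16
assert np.array_equal(unpack_nibbles(pack_nibbles(idx)), idx)
```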

Initialize TQ4 attention with compression primitives.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
def __init__(self, *args, **kwargs) -> None:
    """Initialize TQ4 attention with compression primitives."""
    super().__init__(*args, **kwargs)

    # Use attributes set by super().__init__()
    head_size = self.head_size
    num_kv_heads = self.num_kv_heads

    # TQ4 compression primitives (deterministic from seed, shared across layers)
    quantizer = TurboQuantMSE(head_size, TQ4_BITS, seed=TQ4_SEED)
    self._tq4_rotation = quantizer.rotation  # (D, D) fp32
    self._tq4_centroids = quantizer.codebook.centroids  # (16,) fp32
    self._tq4_boundaries = quantizer.codebook.boundaries  # (15,) fp32
    # Pre-split rotation.T for fused compress kernel (contiguous loads)
    rot_t = quantizer.rotation.T.contiguous()
    self._tq4_rot_T_even = rot_t[:, 0::2].contiguous()  # (D, D//2) fp32
    self._tq4_rot_T_odd = rot_t[:, 1::2].contiguous()  # (D, D//2) fp32
    self._tq4_on_device = False

    # Byte layout offsets within the last dimension of the packed cache.
    # Layout: [K_indices(H*D/2) | K_norms(H*4) | V_indices(H*D/2) | V_norms(H*4)]
    half_D = head_size // 2
    self._half_D = half_D
    self._k_idx_end = num_kv_heads * half_D
    self._k_norm_end = self._k_idx_end + num_kv_heads * TQ4_NORM_BYTES
    self._v_idx_end = self._k_norm_end + num_kv_heads * half_D
    self._total_bytes = self._v_idx_end + num_kv_heads * TQ4_NORM_BYTES

    logger.info(
        "TQ4AttentionImpl: %d KV heads, head_size=%d, "
        "%d bytes/token (%.2fx compression vs FP16)",
        num_kv_heads,
        head_size,
        self._total_bytes,
        (2 * num_kv_heads * head_size * 2) / self._total_bytes,
    )
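A worked example of the offsets computed in `__init__` above, assuming head_size=128, num_kv_heads=8, and TQ4_NORM_BYTES=4 (one fp32 norm per head; these parameter values are illustrative):

```python
# Worked example of the packed-cache byte offsets for the layout
# [K_indices | K_norms | V_indices | V_norms], with assumed values
# head_size=128, num_kv_heads=8, TQ4_NORM_BYTES=4.
head_size, num_kv_heads, TQ4_NORM_BYTES = 128, 8, 4
half_D = head_size // 2                                  # 64 index bytes/head
k_idx_end = num_kv_heads * half_D                        # 512
k_norm_end = k_idx_end + num_kv_heads * TQ4_NORM_BYTES   # 544
v_idx_end = k_norm_end + num_kv_heads * half_D           # 1056
total_bytes = v_idx_end + num_kv_heads * TQ4_NORM_BYTES  # 1088
fp16_bytes = 2 * num_kv_heads * head_size * 2            # 4096 for FP16 K+V
print(total_bytes, round(fp16_bytes / total_bytes, 2))   # 1088 3.76
```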
Functions
forward
forward(
    layer,
    query,
    key,
    value,
    kv_cache,
    attn_metadata,
    output=None,
    output_scale=None,
    output_block_scale=None,
)

TQ4 attention: compress -> store -> pre-rotate Q -> decompress -> FA -> post-rotate.

Phase 3c.8: Uses fused Triton decompress (no rotation). The rotation is applied to Q before attention and to the output after, saving O(cache_len) matmuls per decode step.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
def forward(
    self,
    layer,
    query,
    key,
    value,
    kv_cache,
    attn_metadata,
    output=None,
    output_scale=None,
    output_block_scale=None,
):
    """TQ4 attention: compress -> store -> pre-rotate Q -> decompress -> FA -> post-rotate.

    Phase 3c.8: Uses fused Triton decompress (no rotation). The
    rotation is applied to Q before attention and to the output
    after, saving O(cache_len) matmuls per decode step.
    """
    assert output is not None

    if output_scale is not None or output_block_scale is not None:
        raise NotImplementedError(
            "Fused output quantization is not supported with TQ4 backend"
        )

    # Profiling mode
    if attn_metadata is None:
        output.zero_()
        return output

    # Encoder attention: no TQ4, delegate to parent
    # (VIT uses a separate backend, but guard just in case)
    from vllm.v1.attention.backend import AttentionType

    if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
        return self._forward_encoder_attention(
            query[: attn_metadata.num_actual_tokens],
            key[: attn_metadata.num_actual_tokens],
            value[: attn_metadata.num_actual_tokens],
            output[: attn_metadata.num_actual_tokens],
            attn_metadata,
            layer,
        )

    num_actual_tokens = attn_metadata.num_actual_tokens

    # Step 1: Compress and store new K/V tokens
    if kv_cache is not None and key is not None and value is not None:
        self._ensure_device(query.device)
        self._compress_and_store(key, value, kv_cache, attn_metadata.slot_mapping)

    # Step 2: Pre-rotate Q by Pi^T (O(num_actual_tokens), not O(cache_len))
    self._ensure_device(query.device)
    q_slice = query[:num_actual_tokens]
    q_rot = (q_slice.float() @ self._tq4_rotation.T).to(q_slice.dtype)

    # Step 3: Decompress full cache (Triton fused, skip rotation)
    key_cache, value_cache = self._decompress_cache(
        kv_cache,
        query.dtype,
        apply_rotation=False,
    )

    # Step 4: Run Flash Attention with rotated Q and rotated KV
    from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func

    if attn_metadata.use_cascade:
        raise NotImplementedError("TQ4 does not yet support cascade attention")

    descale_shape = (
        attn_metadata.query_start_loc.shape[0] - 1,
        self.num_kv_heads,
    )
    q_descale = layer._q_scale.expand(descale_shape)
    k_descale = layer._k_scale.expand(descale_shape)
    v_descale = layer._v_scale.expand(descale_shape)

    flash_attn_varlen_func(
        q=q_rot,
        k=key_cache,
        v=value_cache,
        out=output[:num_actual_tokens],
        cu_seqlens_q=attn_metadata.query_start_loc,
        max_seqlen_q=attn_metadata.max_query_len,
        seqused_k=attn_metadata.seq_lens,
        max_seqlen_k=attn_metadata.max_seq_len,
        softmax_scale=self.scale,
        causal=attn_metadata.causal,
        alibi_slopes=self.alibi_slopes,  # ty: ignore[invalid-argument-type]
        window_size=list(self.sliding_window)
        if self.sliding_window is not None
        else None,
        block_table=attn_metadata.block_table,
        softcap=self.logits_soft_cap,
        scheduler_metadata=attn_metadata.scheduler_metadata,
        fa_version=self.vllm_flash_attn_version,  # ty: ignore[invalid-argument-type]
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
        num_splits=attn_metadata.max_num_splits,
        s_aux=self.sinks,
    )

    # Step 5: Post-rotate output by Pi (undo rotation space)
    out_slice = output[:num_actual_tokens]
    output[:num_actual_tokens] = (out_slice.float() @ self._tq4_rotation).to(
        out_slice.dtype
    )

    return output
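The correctness of the pre/post-rotation in steps 2 and 5 rests on the rotation being orthogonal: rotating Q and K by the same matrix leaves their dot products, and hence the attention scores, unchanged. A NumPy sketch, using an arbitrary random orthogonal matrix as a stand-in for the real TQ4 rotation:

```python
import numpy as np

# For orthogonal R: (q @ R.T) . (k @ R.T) = q @ R.T @ R @ k = q . k,
# so scoring a rotated query against rotated cached keys reproduces
# the original attention scores. R below is a random orthogonal
# matrix (via QR), standing in for the actual TQ4 rotation.
rng = np.random.default_rng(0)
R, _ = np.linalg.qr(rng.standard_normal((8, 8)))  # R.T @ R == I
q = rng.standard_normal(8)
k = rng.standard_normal(8)
assert np.allclose((q @ R.T) @ (k @ R.T), q @ k)
```

This is why only Q (O(num_actual_tokens) work) and the output need rotating each step, rather than every cached key and value.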

Functions

register_tq4_backend

register_tq4_backend() -> None

Register TQ4 as the CUSTOM attention backend.

In addition to registering the backend class, this monkey-patches Attention.get_kv_cache_spec so that decoder attention layers return TQ4FullAttentionSpec (with dtype=torch.uint8 and TQ4-sized pages) instead of the standard FullAttentionSpec.

Called automatically by the vllm.general_plugins entry point, or manually before starting vLLM:

from turboquant_vllm.vllm import register_tq4_backend

register_tq4_backend()
# then start vLLM with --attention-backend CUSTOM
Source code in src/turboquant_vllm/vllm/tq4_backend.py
def register_tq4_backend() -> None:
    """Register TQ4 as the CUSTOM attention backend.

    In addition to registering the backend class, this monkey-patches
    ``Attention.get_kv_cache_spec`` so that decoder attention layers
    return :class:`TQ4FullAttentionSpec` (with ``dtype=torch.uint8``
    and TQ4-sized pages) instead of the standard ``FullAttentionSpec``.

    Called automatically by the ``vllm.general_plugins`` entry point,
    or manually before starting vLLM::

        from turboquant_vllm.vllm import register_tq4_backend

        register_tq4_backend()
        # then start vLLM with --attention-backend CUSTOM
    """
    global _original_get_kv_cache_spec  # noqa: PLW0603

    register_backend(
        AttentionBackendEnum.CUSTOM,
        "turboquant_vllm.vllm.tq4_backend.TQ4AttentionBackend",
    )

    # Register TQ4FullAttentionSpec in the KV cache manager mapping.
    # vLLM uses exact type() match, not isinstance(), so subclasses
    # of FullAttentionSpec must be explicitly added.
    from vllm.v1.core.single_type_kv_cache_manager import spec_manager_map

    if TQ4FullAttentionSpec not in spec_manager_map:
        spec_manager_map[TQ4FullAttentionSpec] = spec_manager_map[FullAttentionSpec]

    # Monkey-patch Attention.get_kv_cache_spec to return TQ4 spec
    from vllm.model_executor.layers.attention.attention import Attention

    if _original_get_kv_cache_spec is None:
        _original_get_kv_cache_spec = Attention.get_kv_cache_spec

    def _tq4_get_kv_cache_spec(self, vllm_config):
        spec = _original_get_kv_cache_spec(self, vllm_config)
        if isinstance(spec, FullAttentionSpec) and not isinstance(
            spec, TQ4FullAttentionSpec
        ):
            kwargs = {f.name: getattr(spec, f.name) for f in dc_fields(spec)}
            kwargs["dtype"] = torch.uint8
            return TQ4FullAttentionSpec(**kwargs)
        return spec

    Attention.get_kv_cache_spec = _tq4_get_kv_cache_spec  # ty: ignore[invalid-assignment]
    logger.info("TQ4 attention backend registered as CUSTOM (packed cache)")