tq4_backend

turboquant_vllm.vllm.tq4_backend

TQ4 compressed KV cache attention backend for vLLM.

Phase 3c: Packed TQ4 cache layout with real VRAM savings.

The KV cache is stored as uint8 bytes in a packed TQ4 format: 68 bytes per token per head for each of K and V, 136 bytes total, versus 512 bytes in FP16, a 3.76x compression. Buffer allocation uses a custom TQ4FullAttentionSpec that overrides page_size_bytes so the block allocator provisions 3.76x more blocks in the same VRAM budget. Each forward() call decompresses the relevant blocks to FP16 and delegates to Flash Attention.
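A quick sanity check of that byte math (a sketch only; head_size = 128 and the 4-byte fp32 norm per head per component are assumptions inferred from the 68-byte figure, not read from the source):

head_size = 128
idx_bytes = head_size // 2        # two 4-bit indices nibble-packed per byte -> 64 bytes
per_component = idx_bytes + 4     # plus one fp32 norm -> 68 bytes for K (or V)
packed_bytes = 2 * per_component  # K + V -> 136 bytes per token per head
fp16_bytes = 2 * head_size * 2    # K + V in FP16 -> 512 bytes per token per head
print(fp16_bytes / packed_bytes)  # ~3.76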

Implementation phases

3a (done): Passthrough skeleton -- validated plugin wiring.
3b (done): Compress-decompress round-trip in the standard FP16 cache.
3c (this): Packed uint8 cache with real VRAM savings.
3d: Production benchmark against the vLLM baseline.

Classes

TQ4FullAttentionSpec dataclass

TQ4FullAttentionSpec()

Bases: FullAttentionSpec

KV cache spec with TQ4 packed page size.

Overrides real_page_size_bytes so the block allocator provisions buffers sized for the packed TQ4 format. Supports asymmetric K/V bit-widths via the TQ4_K_BITS / TQ4_V_BITS env vars. Follows the same pattern as MLAAttentionSpec, which overrides the page size for the 656-byte FlashMLA format.
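As a rough illustration of what the override buys, here is a hypothetical page-size calculation (the helper name and example values are illustrative; only the 136-byte per-token figure comes from the module docstring, and the padding for hybrid-model page alignment mentioned under get_kv_cache_shape is ignored):

def tq4_page_size_bytes(block_size: int, num_kv_heads: int) -> int:
    packed_bytes_per_token_per_head = 136   # 68 bytes for K + 68 bytes for V
    return block_size * num_kv_heads * packed_bytes_per_token_per_head

print(tq4_page_size_bytes(block_size=16, num_kv_heads=4))  # 8704 bytes per block
# FP16 equivalent: 16 * 4 * 512 = 32768 bytes, so ~3.76x more blocks fit in the same budget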

TQ4MetadataBuilder

Bases: FlashAttentionMetadataBuilder

Metadata builder for TQ4 with conditional CUDA graph support.

CUDA graphs are supported for single-token decode only when the fused paged kernel is available; otherwise CG support is NEVER (the paged decompress path has dynamic allocations). Inherits all metadata-building logic from Flash Attention; only the CUDA graph support level differs.

Functions
get_cudagraph_support classmethod
get_cudagraph_support(vllm_config: object, kv_cache_spec: object) -> AttentionCGSupport

Report CUDA graph support: single-token decode when fused available.

When fused paged decode is available, decode goes through _fused_decode_path (CG-safe). Otherwise, decode uses _decompress_cache_paged which has 10+ non-CG-safe operations (torch.unique, boolean indexing, dynamic allocations).

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@classmethod
def get_cudagraph_support(
    cls,
    vllm_config: object,
    kv_cache_spec: object,
) -> AttentionCGSupport:
    """Report CUDA graph support: single-token decode when fused available.

    When fused paged decode is available, decode goes through
    ``_fused_decode_path`` (CG-safe).  Otherwise, decode uses
    ``_decompress_cache_paged`` which has 10+ non-CG-safe operations
    (torch.unique, boolean indexing, dynamic allocations).
    """
    from vllm.v1.attention.backend import AttentionCGSupport

    k_bits, v_bits = _parse_kv_bits_env()
    if (
        _parse_fused_paged_env()
        and _fused_paged_kernel_available
        and k_bits == v_bits
    ):
        return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
    return AttentionCGSupport.NEVER

TQ4AttentionBackend

Bases: FlashAttentionBackend

TQ4 compressed KV cache attention backend.

Phase 3c: packed uint8 cache layout with real VRAM savings. The cache stores nibble-packed TQ4 indices + fp32 norms as raw bytes. get_kv_cache_shape() returns a 3D (NB, BS, bytes_per_token) layout matching the packed format.

Functions
supports_mm_prefix classmethod
supports_mm_prefix() -> bool

Required for VLMs like Molmo2 with bidirectional visual tokens.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@classmethod
def supports_mm_prefix(cls) -> bool:
    """Required for VLMs like Molmo2 with bidirectional visual tokens."""
    return True
get_name staticmethod
get_name() -> str

Must return "CUSTOM" to match AttentionBackendEnum.CUSTOM.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_name() -> str:
    """Must return ``"CUSTOM"`` to match ``AttentionBackendEnum.CUSTOM``."""
    return "CUSTOM"
get_impl_cls staticmethod
get_impl_cls() -> type[AttentionImplBase]

Return TQ4AttentionImpl.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_impl_cls() -> type[AttentionImplBase]:
    """Return :class:`TQ4AttentionImpl`."""
    return TQ4AttentionImpl
get_builder_cls staticmethod
get_builder_cls() -> type[AttentionMetadataBuilder]

Return TQ4MetadataBuilder for CUDA graph support.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_builder_cls() -> type[AttentionMetadataBuilder]:
    """Return :class:`TQ4MetadataBuilder` for CUDA graph support."""
    return TQ4MetadataBuilder
get_kv_cache_shape staticmethod
get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]

Packed TQ4 cache: (num_blocks, block_size, padded_bytes).

The last dimension packs K and V data for all heads as raw bytes with padding for hybrid model page alignment. Only the first num_kv_heads * _tq4_bytes_per_token_kv(head_size) bytes per token contain packed data; trailing bytes are unused padding.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    """Packed TQ4 cache: ``(num_blocks, block_size, padded_bytes)``.

    The last dimension packs K and V data for all heads as raw bytes
    with padding for hybrid model page alignment. Only the first
    ``num_kv_heads * _tq4_bytes_per_token_kv(head_size)`` bytes per
    token contain packed data; trailing bytes are unused padding.
    """
    total_bytes = num_kv_heads * _padded_slot_bytes(head_size)
    return (num_blocks, block_size, total_bytes)
get_kv_cache_stride_order staticmethod
get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]

Raise to trigger identity fallback in reshape.

The inherited FlashAttentionBackend returns a 5-element stride order for the standard (2, NB, BS, H, D) shape. Our 3D packed layout (NB, BS, total_bytes) needs identity ordering. Raising NotImplementedError triggers the fallback in _reshape_kv_cache_tensors (the same pattern as FlashMLA, which does not implement this method at all).

Source code in src/turboquant_vllm/vllm/tq4_backend.py
@staticmethod
def get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
    """Raise to trigger identity fallback in reshape.

    The inherited FlashAttentionBackend returns a 5-element stride
    order for the standard ``(2, NB, BS, H, D)`` shape. Our 3D
    packed layout ``(NB, BS, total_bytes)`` needs identity ordering.
    Raising ``NotImplementedError`` triggers the fallback in
    ``_reshape_kv_cache_tensors`` (same pattern as FlashMLA which
    does not implement this method at all).
    """
    raise NotImplementedError

TQ4AttentionImpl

TQ4AttentionImpl(*args, **kwargs)

Bases: FlashAttentionImpl

TQ4 attention: compress -> store -> decompress -> Flash Attention.

Phase 3c: stores packed TQ4 bytes in a uint8 cache for real VRAM savings. Each forward() call (steps 1-2 are sketched in code after the list):

  1. Compresses incoming K/V tokens to TQ4 packed bytes.
  2. Scatter-writes packed bytes to the uint8 cache via slot_mapping.
  3. Decompresses the full cache to FP16 for Flash Attention.
  4. Calls flash_attn_varlen_func directly with the FP16 data.
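A minimal, self-contained sketch of steps 1-2, the nibble-pack and the slot_mapping scatter-write. The single-component layout, the 4-byte norm, and the pure-PyTorch packing are illustrative stand-ins for the fused Triton kernels and the real [K_indices | K_norms | V_indices | V_norms] layout:

import torch

# Toy dimensions; the real cache shape comes from get_kv_cache_shape().
num_blocks, block_size, head_size = 4, 16, 128
bytes_per_token = head_size // 2 + 4                     # packed 4-bit indices + one fp32 norm
kv_cache = torch.zeros(num_blocks, block_size, bytes_per_token, dtype=torch.uint8)

# Step 1: compress three new tokens into packed uint8 rows.
codes = torch.randint(0, 16, (3, head_size), dtype=torch.uint8)  # 4-bit codes per dimension
packed_idx = (codes[:, 0::2] << 4) | codes[:, 1::2]              # nibble-pack: two codes per byte
norms = torch.rand(3, 1).view(torch.uint8)                       # each fp32 norm as 4 raw bytes
packed = torch.cat([packed_idx, norms], dim=1)                   # (3, bytes_per_token)

# Step 2: scatter-write the packed rows into the flattened cache via slot_mapping.
slot_mapping = torch.tensor([0, 1, 17])                          # block_id * block_size + offset
kv_cache.view(-1, bytes_per_token).index_copy_(0, slot_mapping, packed)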

Initialize TQ4 attention with compression primitives.

Source code in src/turboquant_vllm/vllm/tq4_backend.py
def __init__(self, *args, **kwargs) -> None:
    """Initialize TQ4 attention with compression primitives."""
    super().__init__(*args, **kwargs)

    # Use attributes set by super().__init__()
    head_size = self.head_size
    num_kv_heads = self.num_kv_heads

    # Resolve per-component bit-widths from env vars
    k_bits, v_bits = _parse_kv_bits_env()
    self._k_bits = k_bits
    self._v_bits = v_bits

    # TQ4 compression primitives (deterministic from seed, shared across layers)
    # Rotation matrix is dim-dependent only (not bits-dependent), so shared.
    k_quantizer = TurboQuantMSE(head_size, k_bits, seed=TQ4_SEED)
    v_quantizer = (
        k_quantizer
        if v_bits == k_bits
        else TurboQuantMSE(head_size, v_bits, seed=TQ4_SEED)
    )

    # Eagerly move primitives to the target device (D7 mod 5).
    # FlashAttentionImpl.__init__ doesn't expose device, but
    # vLLM's global config is available during model construction.
    from vllm.config import get_current_vllm_config_or_none

    vllm_config = get_current_vllm_config_or_none()
    device = (
        vllm_config.device_config.device
        if vllm_config is not None
        else torch.device("cpu")
    )

    # Shared rotation (dim-dependent only, same seed → identical matrix)
    self._tq4_rotation = k_quantizer.rotation.to(device)  # (D, D) fp32
    # Pre-split rotation.T for fused compress kernel (contiguous loads)
    rot_t = k_quantizer.rotation.T.contiguous()
    self._tq4_rot_T_even = rot_t[:, 0::2].contiguous().to(device)  # (D, D//2) fp32
    self._tq4_rot_T_odd = rot_t[:, 1::2].contiguous().to(device)  # (D, D//2) fp32

    # Per-component codebooks (may differ when k_bits != v_bits)
    self._k_centroids = k_quantizer.codebook.centroids.to(device)
    self._k_boundaries = k_quantizer.codebook.boundaries.to(device)
    self._v_centroids = v_quantizer.codebook.centroids.to(device)
    self._v_boundaries = v_quantizer.codebook.boundaries.to(device)

    # Byte layout offsets within the last dimension of the packed cache.
    # Layout: [K_indices | K_norms | V_indices | V_norms]
    # The Triton compress kernel always nibble-packs (head_dim // 2 bytes
    # per component), regardless of bit-width. Different codebook sizes
    # provide quality improvement, not storage savings in the vLLM path.
    k_idx_size = head_size // 2
    v_idx_size = head_size // 2
    self._k_idx_size = k_idx_size
    self._v_idx_size = v_idx_size
    self._k_idx_end = num_kv_heads * k_idx_size
    self._k_norm_end = self._k_idx_end + num_kv_heads * TQ4_NORM_BYTES
    self._v_idx_end = self._k_norm_end + num_kv_heads * v_idx_size
    self._total_bytes = self._v_idx_end + num_kv_heads * TQ4_NORM_BYTES

    # CUDA graph scratch buffers (D7 mod 2) — lazy-allocated on first
    # forward() from kv_cache.shape, which is stable for engine lifetime.
    # First forward runs during vLLM warmup, before graph capture.
    self._cg_buffers_ready = False

    # Fused paged decode feature gate (Story 6.3, AC 1+6).
    # Explicit opt-in via TQ4_USE_FUSED_PAGED env var AND successful
    # kernel import.  Disabled for asymmetric configs — fused kernel
    # uses a single codebook and cannot handle k_bits != v_bits.
    self._fused_paged_available = (
        _parse_fused_paged_env()
        and _fused_paged_kernel_available
        and k_bits == v_bits
    )

    # INT8 prefill gate (Story 6.4): requires fused decode gate + its own
    # env var + successful kernel import.
    self._int8_prefill_available = (
        self._fused_paged_available
        and _parse_int8_prefill_env()
        and _int8_prefill_kernel_available
    )

    # Buffer downsizing source: scheduler knows its own max prefill length.
    # Fallback 2048 matches vLLM's default max_num_batched_tokens for
    # chunked prefill.
    self._max_prefill_len = (
        vllm_config.scheduler_config.max_num_batched_tokens
        if vllm_config is not None
        else 2048
    )

    # Decode buffer bound: max_model_len caps decompress buffer instead
    # of full cache capacity.  Fallback 6144 matches Molmo2 default.
    self._max_model_len = (
        vllm_config.model_config.max_model_len if vllm_config is not None else 6144
    )

    logger.info(
        "TQ4AttentionImpl: %d KV heads, head_size=%d, k_bits=%d, v_bits=%d, "
        "%d bytes/token (%.2fx compression vs FP16)",
        num_kv_heads,
        head_size,
        k_bits,
        v_bits,
        self._total_bytes,
        (2 * num_kv_heads * head_size * 2) / self._total_bytes,
    )
    logger.info(
        "Fused paged TQ4 decode: %s",
        "enabled" if self._fused_paged_available else "disabled",
    )
    logger.info(
        "INT8 prefill path: %s",
        "enabled" if self._int8_prefill_available else "disabled",
    )
Functions
forward
forward(
    layer,
    query,
    key,
    value,
    kv_cache,
    attn_metadata,
    output=None,
    output_scale=None,
    output_block_scale=None,
)

TQ4 attention: compress -> store -> pre-rotate Q -> decompress -> FA -> post-rotate.

Phase 3c.8: Uses fused Triton decompress (no rotation). The rotation is applied to Q before attention and to the output after, saving O(cache_len) matmuls per decode step.
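The equivalence rests on the rotation being orthogonal: rotating Q and K by the same orthogonal matrix leaves the attention logits unchanged, and rotating V only rotates the attention output, which a single matmul (step 5 in the source below) maps back. A toy check of the logit identity (float64 for tight tolerances; not the production kernel):

import torch

D = 128
R, _ = torch.linalg.qr(torch.randn(D, D, dtype=torch.float64))  # random orthogonal matrix
q = torch.randn(4, D, dtype=torch.float64)
k = torch.randn(16, D, dtype=torch.float64)
# (q @ R) @ (k @ R).T == q @ k.T, so scores against rotated keys match the originals;
# likewise softmax(S) @ (V @ R) == (softmax(S) @ V) @ R, so the output just lands in the rotated basis.
assert torch.allclose((q @ R) @ (k @ R).T, q @ k.T)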

Source code in src/turboquant_vllm/vllm/tq4_backend.py
def forward(
    self,
    layer,
    query,
    key,
    value,
    kv_cache,
    attn_metadata,
    output=None,
    output_scale=None,
    output_block_scale=None,
):
    """TQ4 attention: compress -> store -> pre-rotate Q -> decompress -> FA -> post-rotate.

    Phase 3c.8: Uses fused Triton decompress (no rotation). The
    rotation is applied to Q before attention and to the output
    after, saving O(cache_len) matmuls per decode step.
    """
    assert output is not None

    if output_scale is not None or output_block_scale is not None:
        raise NotImplementedError(
            "Fused output quantization is not supported with TQ4 backend"
        )

    # Profiling mode
    if attn_metadata is None:
        output.zero_()
        return output

    # Warmup with no cache allocated yet
    if kv_cache is None:
        output.zero_()
        return output

    # Encoder attention: no TQ4, delegate to parent
    # (VIT uses a separate backend, but guard just in case)
    from vllm.v1.attention.backend import AttentionType

    if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
        return self._forward_encoder_attention(
            query[: attn_metadata.num_actual_tokens],
            key[: attn_metadata.num_actual_tokens],
            value[: attn_metadata.num_actual_tokens],
            output[: attn_metadata.num_actual_tokens],
            attn_metadata,
            layer,
        )

    num_actual_tokens = attn_metadata.num_actual_tokens

    # Lazy-init CUDA graph buffers on first forward (during warmup)
    if not self._cg_buffers_ready and kv_cache is not None:
        self._init_cg_buffers(kv_cache, compute_dtype=query.dtype)

    # Steps 1-3: compress, rotate Q, decompress (decode vs prefill path)
    is_decode = self._cg_buffers_ready and num_actual_tokens == 1

    # Fused paged decode (Story 6.3): single kernel replaces
    # decompress + FlashAttn + post-rotate for decode steps.
    if self._fused_paged_available and is_decode:
        return self._fused_decode_path(
            query, key, value, kv_cache, attn_metadata, output
        )

    # INT8 prefill (Story 6.4): IMMA tensor core Q@K^T for prefill.
    # Guard: kernel is single-sequence only; fall back for multi-sequence
    # batches (vLLM scheduler may combine multiple requests).
    if (
        self._int8_prefill_available
        and not is_decode
        and attn_metadata.seq_lens.shape[0] == 1
    ):
        return self._int8_prefill_path(
            query, key, value, kv_cache, attn_metadata, output
        )

    if is_decode:
        q_rot, key_cache, value_cache, fa_block_table = self._tq4_decode(
            query, key, value, kv_cache, attn_metadata
        )
    else:
        q_rot, key_cache, value_cache, fa_block_table = self._tq4_prefill(
            query, key, value, kv_cache, attn_metadata
        )

    # Step 4: Run Flash Attention with rotated Q and rotated KV
    from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func

    if attn_metadata.use_cascade:
        raise NotImplementedError("TQ4 does not yet support cascade attention")

    descale_shape = (
        attn_metadata.query_start_loc.shape[0] - 1,
        self.num_kv_heads,
    )
    q_descale = layer._q_scale.expand(descale_shape)
    k_descale = layer._k_scale.expand(descale_shape)
    v_descale = layer._v_scale.expand(descale_shape)

    flash_attn_varlen_func(
        q=q_rot,
        k=key_cache,
        v=value_cache,
        out=output[:num_actual_tokens],
        cu_seqlens_q=attn_metadata.query_start_loc,
        max_seqlen_q=attn_metadata.max_query_len,
        seqused_k=attn_metadata.seq_lens,
        max_seqlen_k=attn_metadata.max_seq_len,
        softmax_scale=self.scale,
        causal=attn_metadata.causal,
        alibi_slopes=self.alibi_slopes,
        window_size=list(self.sliding_window)
        if self.sliding_window is not None
        else None,
        block_table=fa_block_table,
        softcap=self.logits_soft_cap,
        scheduler_metadata=attn_metadata.scheduler_metadata,
        fa_version=self.vllm_flash_attn_version,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
        num_splits=attn_metadata.max_num_splits,
        s_aux=self.sinks,
    )

    # Step 5: Post-rotate output by Pi (undo rotation space)
    out_slice = output[:num_actual_tokens]
    output[:num_actual_tokens] = (out_slice.float() @ self._tq4_rotation).to(
        out_slice.dtype
    )

    return output

Functions

register_tq4_backend

register_tq4_backend() -> None

Register TQ4 as the CUSTOM attention backend.

In addition to registering the backend class, this monkey-patches Attention.get_kv_cache_spec so that decoder attention layers return TQ4FullAttentionSpec (with dtype=torch.uint8 and TQ4-sized pages) instead of the standard FullAttentionSpec.

Called automatically by the vllm.general_plugins entry point, or manually before starting vLLM:

from turboquant_vllm.vllm import register_tq4_backend

register_tq4_backend()
# then start vLLM with --attention-backend CUSTOM
Source code in src/turboquant_vllm/vllm/tq4_backend.py
def register_tq4_backend() -> None:
    """Register TQ4 as the CUSTOM attention backend.

    In addition to registering the backend class, this monkey-patches
    ``Attention.get_kv_cache_spec`` so that decoder attention layers
    return :class:`TQ4FullAttentionSpec` (with ``dtype=torch.uint8``
    and TQ4-sized pages) instead of the standard ``FullAttentionSpec``.

    Called automatically by the ``vllm.general_plugins`` entry point,
    or manually before starting vLLM::

        from turboquant_vllm.vllm import register_tq4_backend

        register_tq4_backend()
        # then start vLLM with --attention-backend CUSTOM
    """
    global _original_get_kv_cache_spec  # noqa: PLW0603

    register_backend(
        AttentionBackendEnum.CUSTOM,
        "turboquant_vllm.vllm.tq4_backend.TQ4AttentionBackend",
    )

    # Register TQ4FullAttentionSpec in the KV cache manager mapping.
    # vLLM uses exact type() match, not isinstance(), so subclasses
    # of FullAttentionSpec must be explicitly added.
    from vllm.v1.core.single_type_kv_cache_manager import spec_manager_map

    if TQ4FullAttentionSpec not in spec_manager_map:
        spec_manager_map[TQ4FullAttentionSpec] = spec_manager_map[FullAttentionSpec]

    # Monkey-patch Attention.get_kv_cache_spec to return TQ4 spec
    from vllm.model_executor.layers.attention.attention import Attention

    if _original_get_kv_cache_spec is None:
        _original_get_kv_cache_spec = Attention.get_kv_cache_spec

    def _tq4_get_kv_cache_spec(self, vllm_config):
        spec = _original_get_kv_cache_spec(self, vllm_config)
        if isinstance(spec, FullAttentionSpec) and not isinstance(
            spec, TQ4FullAttentionSpec
        ):
            kwargs = {f.name: getattr(spec, f.name) for f in dc_fields(spec)}
            kwargs["dtype"] = torch.uint8
            return TQ4FullAttentionSpec(**kwargs)
        return spec

    Attention.get_kv_cache_spec = _tq4_get_kv_cache_spec
    logger.info("TQ4 attention backend registered as CUSTOM (packed cache)")