fused_paged_tq4_attention

turboquant_vllm.triton.fused_paged_tq4_attention

Fused paged TQ4 decode attention -- decompresses directly from page table.

Phase 3a of the D9 kernel roadmap. This kernel reads TQ4-compressed blocks directly from vLLM's paged block table, decompresses them in SRAM (nibble unpack -> centroid gather -> norm scale), and computes FP16 Q@K^T with online softmax in a single fused pass. The decompressed cache is never written back to HBM, so HBM traffic drops from 1,160 to 136 bytes/token (an 8.5x reduction).
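The "online softmax in a single fused pass" idea can be illustrated in plain Python. This is only a sketch of the general technique, not the Triton kernel: the function names, the tile size, and the use of Python lists are assumptions made for illustration. It shows that a running maximum, a running denominator, and a rescaled accumulator give the same result as a two-pass softmax, without ever materialising all scores at once.

```python
import math


def online_softmax_attention(scores, values, tile=2):
    """One pass over (score, value) tiles with a running max and denominator.

    Hypothetical helper for illustration; `scores` must be non-empty.
    """
    running_max = float("-inf")
    denom = 0.0
    acc = [0.0] * len(values[0])  # running weighted sum of value vectors
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        new_max = max(running_max, max(s_tile))
        # Rescale everything accumulated so far to the new maximum.
        alpha = math.exp(running_max - new_max)
        denom *= alpha
        acc = [a * alpha for a in acc]
        for s, v in zip(s_tile, v_tile):
            p = math.exp(s - new_max)
            denom += p
            acc = [a + p * x for a, x in zip(acc, v)]
        running_max = new_max
    return [a / denom for a in acc]


def two_pass_attention(scores, values):
    """Reference: ordinary softmax over all scores, then a weighted sum."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / z for d in range(dim)]
```

In the fused kernel the tiles would be the decompressed K/V blocks, which is why no decompressed cache needs to be written out between steps.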

The kernel operates entirely in rotated space. Its caller (the Python wrapper fused_paged_tq4_decode) pre-rotates Q by Pi^T and post-rotates the output by Pi. Decompression does NOT apply rotation (matching tq4_decompress.py).
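The three decompression steps named above (nibble unpack, centroid gather, norm scale) can be sketched for a single head vector as follows. Several details are assumptions rather than facts from this page: the nibble order within a byte, the little-endian fp32 encoding of the norm, and the helper name itself. Note that no rotation is applied, so the result stays in rotated space.

```python
import struct


def tq4_decompress_vector(packed, norm_bytes, centroids):
    """Hypothetical helper: rebuild one rotated-space vector from TQ4 bytes.

    packed     -- head_dim // 2 bytes, two 4-bit codebook indices per byte
    norm_bytes -- 4 bytes holding the vector norm
    centroids  -- 16-entry codebook
    """
    indices = []
    for byte in packed:
        # ASSUMPTION: low nibble comes first, high nibble second.
        indices.append(byte & 0x0F)
        indices.append((byte >> 4) & 0x0F)
    # ASSUMPTION: the norm is stored as a little-endian fp32 value.
    (norm,) = struct.unpack("<f", norm_bytes)
    # Centroid gather followed by norm scale; no rotation is applied here.
    return [centroids[i] * norm for i in indices]
```

For example, with a made-up codebook `centroids = [i / 8 - 1 for i in range(16)]`, the bytes `0x10 0xF8` unpack to indices 0, 1, 8, 15 under the assumed nibble order.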

Scope: FP16/BF16 Q decode path only (USE_INT8_QK=False). INT8 path is Story 6.4. Placeholder parameters are included for forward compatibility but compiled out by the constexpr switch.

Autotune: 8 configs (BLOCK_N in {32, 64} x stages {2, 3} x warps {4, 8}). BLOCK_N=16 was dropped after Experiment 020 profiling showed it to be consistently the slowest across 1K-32K context on RTX 4090.
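The autotune search space described above is a plain Cartesian product. The sketch below only enumerates it as dictionaries; the constant and function names are hypothetical, and in a Triton kernel each entry would typically be wrapped in a `triton.Config` and passed to `triton.autotune`, which is not shown here.

```python
from itertools import product

# Values taken from the description above; BLOCK_N=16 is intentionally absent.
BLOCK_N_CHOICES = (32, 64)
NUM_STAGES_CHOICES = (2, 3)
NUM_WARPS_CHOICES = (4, 8)


def autotune_search_space():
    """Hypothetical helper: list every (BLOCK_N, stages, warps) combination."""
    return [
        {"BLOCK_N": block_n, "num_stages": stages, "num_warps": warps}
        for block_n, stages, warps in product(
            BLOCK_N_CHOICES, NUM_STAGES_CHOICES, NUM_WARPS_CHOICES
        )
    ]
```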

Attributes:

Name (Type): Description

fused_paged_tq4_decode (Tensor): Python wrapper that pre-rotates Q, launches the fused paged kernel, and post-rotates the output.

Examples:

from turboquant_vllm.triton.fused_paged_tq4_attention import (
    fused_paged_tq4_decode,
)

out = fused_paged_tq4_decode(
    q,
    kv_cache,
    block_table,
    seq_lens,
    centroids,
    rotation,
    num_kv_heads=4,
    head_dim=128,
    block_size=16,
)
See Also

turboquant_vllm.triton.flash_attention_tq4_kv: Contiguous (non-paged) reference kernel, used as the correctness baseline.

turboquant_vllm.triton.tq4_decompress: Standalone decompress.

Functions

fused_paged_tq4_decode

fused_paged_tq4_decode(
    q: Tensor,
    kv_cache: Tensor,
    block_table: Tensor,
    seq_lens: Tensor,
    centroids: Tensor,
    rotation: Tensor,
    num_kv_heads: int,
    head_dim: int,
    block_size: int,
    sm_scale: float | None = None,
    out: Tensor | None = None,
) -> Tensor

Fused paged TQ4 decode attention.

Pre-rotates Q by rotation^T, launches the fused paged kernel that decompresses TQ4 blocks in-tile from the page table, then post-rotates the output by rotation to return to original space.
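Why the pre-rotate / post-rotate pair leaves the result unchanged can be checked numerically. The sketch below uses a 2x2 rotation and pure-Python matrix helpers; all function names are hypothetical, the softmax scale is omitted for brevity, and it assumes the cached K and V were stored as `x @ rotation^T` (the same convention the wrapper applies to Q). Because the rotation is orthogonal, the scores are identical in both spaces and multiplying the output by `rotation` undoes the rotation of V.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    z = sum(exps)
    return [e / z for e in exps]


def plain_attention(q, k, v):
    scores = matmul(q, transpose(k))
    return matmul([softmax(r) for r in scores], v)


def rotated_attention(q, k, v, rotation):
    rot_t = transpose(rotation)
    q_rot = matmul(q, rot_t)  # pre-rotate Q by rotation^T
    # ASSUMPTION: the cache holds K and V already rotated the same way.
    k_rot = matmul(k, rot_t)
    v_rot = matmul(v, rot_t)
    out_rot = plain_attention(q_rot, k_rot, v_rot)  # all work in rotated space
    return matmul(out_rot, rotation)  # post-rotate back to original space
```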

Parameters:

Name (Type, Default): Description

q (Tensor, required): Query [num_seqs, H_Q, head_dim] fp16/bf16 (one token per seq).

kv_cache (Tensor, required): Packed paged cache [num_blocks, block_size, total_bytes] uint8.

block_table (Tensor, required): Page table [num_seqs, max_num_blocks_per_seq] int32.

seq_lens (Tensor, required): Sequence lengths [num_seqs] int32.

centroids (Tensor, required): TQ4 codebook [16] fp32.

rotation (Tensor, required): Orthogonal rotation [head_dim, head_dim] fp32.

num_kv_heads (int, required): Number of KV heads.

head_dim (int, required): Head dimension (e.g. 128).

block_size (int, required): vLLM page size (tokens per block).

sm_scale (float | None, default None): Softmax scale. Defaults to 1 / sqrt(head_dim).

out (Tensor | None, default None): Optional pre-allocated output [num_seqs, H_Q, head_dim]. When provided, the final post-rotated result is copied into this buffer and returned. A private scratch buffer is used internally for the kernel's rotated-space output.
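How block_table, seq_lens, and block_size fit together can be sketched as a lookup from a token position to a physical cache slot. This follows the usual paged-KV convention (logical block = position // block_size); the kernel's exact indexing is not shown on this page, so treat the helper names and the mapping as assumptions.

```python
def locate_token(block_table, seq_idx, pos, block_size):
    """Hypothetical helper: map a token position to (physical_block, slot)."""
    logical_block = pos // block_size
    slot = pos % block_size
    physical_block = block_table[seq_idx][logical_block]
    return physical_block, slot


def iter_cached_tokens(block_table, seq_lens, seq_idx, block_size):
    """Yield the cache location of every valid token of one sequence.

    Entries of block_table beyond the sequence length are never read.
    """
    for pos in range(seq_lens[seq_idx]):
        yield locate_token(block_table, seq_idx, pos, block_size)
```

With block_size=16, position 17 of a sequence whose page-table row is [7, 2, 9] resolves to physical block 2, slot 1, which would index kv_cache[2, 1, :].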

Returns:

Type: Description

Tensor: Attention output [num_seqs, H_Q, head_dim] in original space. When out is provided, returns out with the result written in-place.

Note

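INT8 placeholder parameters (Q_scale, QJL_S, QJL_signs, QJL_residual_norms) belong to the underlying Triton kernel, not to this wrapper. With USE_INT8_QK=False (the default for Phase 3a), the wrapper passes an empty dummy tensor for each of them and sets QJL_DIM=0, so callers of fused_paged_tq4_decode do not supply them.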

Source code in src/turboquant_vllm/triton/fused_paged_tq4_attention.py
def fused_paged_tq4_decode(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    seq_lens: torch.Tensor,
    centroids: torch.Tensor,
    rotation: torch.Tensor,
    num_kv_heads: int,
    head_dim: int,
    block_size: int,
    sm_scale: float | None = None,
    out: torch.Tensor | None = None,
) -> torch.Tensor:
    """Fused paged TQ4 decode attention.

    Pre-rotates Q by ``rotation^T``, launches the fused paged kernel
    that decompresses TQ4 blocks in-tile from the page table, then
    post-rotates the output by ``rotation`` to return to original space.

    Args:
        q: Query ``[num_seqs, H_Q, head_dim]`` fp16/bf16 (one token per seq).
        kv_cache: Packed paged cache ``[num_blocks, block_size, total_bytes]``
            uint8.
        block_table: Page table ``[num_seqs, max_num_blocks_per_seq]`` int32.
        seq_lens: Sequence lengths ``[num_seqs]`` int32.
        centroids: TQ4 codebook ``[16]`` fp32.
        rotation: Orthogonal rotation ``[head_dim, head_dim]`` fp32.
        num_kv_heads: Number of KV heads.
        head_dim: Head dimension (e.g. 128).
        block_size: vLLM page size (tokens per block).
        sm_scale: Softmax scale.  Defaults to ``1 / sqrt(head_dim)``.
        out: Optional pre-allocated output ``[num_seqs, H_Q, head_dim]``.
            When provided, the final post-rotated result is copied into
            this buffer and returned.  A private scratch buffer is used
            internally for the kernel's rotated-space output.

    Returns:
        Attention output ``[num_seqs, H_Q, head_dim]`` in original space.
        When ``out`` is provided, returns ``out`` with the result written
        in-place.

    Note:
        INT8 placeholder parameters (``Q_scale``, ``QJL_S``, ``QJL_signs``,
        ``QJL_residual_norms``) should be passed as ``None``/zeros when
        ``USE_INT8_QK=False`` (the default for Phase 3a).
    """
    num_seqs, H_Q, D = q.shape

    assert D == head_dim
    assert H_Q % num_kv_heads == 0
    assert kv_cache.dtype == torch.uint8
    assert block_table.dtype == torch.int32
    assert seq_lens.dtype == torch.int32

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)

    half_D = head_dim // 2

    # Byte layout constexprs
    k_norm_offset = num_kv_heads * half_D
    v_idx_offset = k_norm_offset + num_kv_heads * 4
    v_norm_offset = v_idx_offset + num_kv_heads * half_D

    # Pre-rotate Q by Pi^T (O(num_seqs), not O(cache_len))
    q_rot = torch.matmul(q.float(), rotation.T).to(q.dtype)

    # Always use a private scratch buffer for the kernel's rotated-space output.
    # When ``out`` is provided, the final (post-rotated) result is copied into it.
    out_rot = torch.empty_like(q)

    # INT8 placeholders (unused, compiled out)
    dummy = torch.empty(0, device=q.device)

    grid = (num_seqs, H_Q)

    _fused_paged_tq4_decode_kernel[grid](
        q_rot,
        kv_cache,
        block_table,
        seq_lens,
        centroids,
        out_rot,
        dummy,  # Q_scale
        dummy,  # QJL_S
        dummy,  # QJL_signs
        dummy,  # QJL_residual_norms
        q_rot.stride(0),
        q_rot.stride(1),
        q_rot.stride(2),
        kv_cache.stride(0),
        kv_cache.stride(1),
        block_table.stride(0),
        block_table.stride(1),
        out_rot.stride(0),
        out_rot.stride(1),
        out_rot.stride(2),
        sm_scale=sm_scale,
        H_Q=H_Q,
        H_KV=num_kv_heads,
        HEAD_DIM=head_dim,
        BLOCK_SIZE=block_size,
        HALF_D=half_D,
        K_NORM_OFFSET=k_norm_offset,
        V_IDX_OFFSET=v_idx_offset,
        V_NORM_OFFSET=v_norm_offset,
        USE_INT8_QK=False,
        QJL_DIM=0,
    )

    # Post-rotate: convert from rotated space back to original space
    result = torch.matmul(out_rot.float(), rotation).to(q.dtype)
    if out is not None:
        out.copy_(result)
        return out
    return result
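The byte-layout offsets computed in the wrapper can be reproduced in isolation to see how one token's slot in kv_cache is laid out. The three offsets below mirror the source above; the region labels, the total_bytes formula, and the helper name are assumptions (the total assumes the V norms occupy 4 bytes per KV head, like the K norms).

```python
def tq4_token_layout(num_kv_heads, head_dim):
    """Hypothetical helper: byte offsets of each region within one token slot."""
    half_d = head_dim // 2  # two 4-bit indices per byte
    k_norm_offset = num_kv_heads * half_d
    v_idx_offset = k_norm_offset + num_kv_heads * 4
    v_norm_offset = v_idx_offset + num_kv_heads * half_d
    # ASSUMPTION: V norms take 4 bytes per head, closing out the slot.
    total_bytes = v_norm_offset + num_kv_heads * 4
    return {
        "k_idx_offset": 0,
        "k_norm_offset": k_norm_offset,
        "v_idx_offset": v_idx_offset,
        "v_norm_offset": v_norm_offset,
        "total_bytes": total_bytes,
    }
```

For num_kv_heads=4 and head_dim=128 this gives offsets 0, 256, 272, 528 and a total of 544 bytes, i.e. 136 bytes per KV head per token. That matches the 136 bytes/token figure quoted in the module description if that figure is counted per KV head, which is an inference rather than something the page states.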