
fused_paged_tq4_int8_prefill

turboquant_vllm.triton.fused_paged_tq4_int8_prefill

Fused paged TQ4 INT8 prefill attention -- IMMA tensor core path.

Phase 3b of the D9 kernel roadmap. This kernel reads TQ4-compressed K/V blocks directly from vLLM's paged KV cache (addressed through the block table), decompresses them in SRAM, re-quantizes Q and K to INT8 per tile, and computes Q@K^T on IMMA tensor cores (mma.sync.aligned.m16n8k32.s8.s8.s32). The P@V accumulation stays on FP16 HMMA.
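The INT8 Q@K^T step can be illustrated in plain Python. This is a minimal sketch, assuming one symmetric absmax scale per tile (the kernel's exact per-tile scaling scheme is not documented here): both tiles are quantized to int8, dot products accumulate as integers (standing in for the s8.s8.s32 IMMA instruction), and the result is rescaled by the product of the two scales. The helper names are hypothetical.

```python
# Sketch of per-tile INT8 quantization for Q @ K^T.
# Assumption: one symmetric absmax scale per tile; the real kernel may differ.


def quantize_tile_int8(tile):
    """Quantize a 2-D tile (list of rows) to int8 values with a single absmax scale."""
    absmax = max(abs(x) for row in tile for x in row)
    scale = absmax / 127.0 if absmax > 0 else 1.0
    q = [[max(-127, min(127, round(x / scale))) for x in row] for row in tile]
    return q, scale


def int8_qk_scores(q_tile, k_tile):
    """Approximate Q @ K^T: integer dot products, then rescale to floating point."""
    q8, sq = quantize_tile_int8(q_tile)
    k8, sk = quantize_tile_int8(k_tile)
    scores = []
    for qrow in q8:
        # Python ints stand in for the s32 accumulator of the IMMA instruction.
        acc = [sum(a * b for a, b in zip(qrow, krow)) for krow in k8]
        scores.append([v * sq * sk for v in acc])
    return scores


def fp_qk_scores(q_tile, k_tile):
    """Full-precision reference Q @ K^T."""
    return [[sum(a * b for a, b in zip(qrow, krow)) for krow in k_tile] for qrow in q_tile]


Q = [[0.5, -1.0, 0.25, 2.0], [1.5, 0.0, -0.75, 0.1]]
K = [[1.0, 0.5, -0.5, 0.2], [-0.3, 0.8, 1.2, -1.0], [0.0, -2.0, 0.4, 0.6]]
approx = int8_qk_scores(Q, K)
exact = fp_qk_scores(Q, K)
max_err = max(abs(a - e) for ra, re in zip(approx, exact) for a, e in zip(ra, re))
print(f"max abs error: {max_err:.4f}")
```

The quantization error stays small relative to the scores, which is why only the Q@K^T product is moved to INT8 while P@V keeps FP16 accumulation.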

Designed for prefill (BLOCK_M=64, compute-bound), where INT8 tensor cores provide a 1.3-2x speedup over FP16. Decode uses the separate FP16 kernel in fused_paged_tq4_attention.py (BLOCK_M=1); decode is memory-bound, so INT8 provides no benefit there.
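The compute-bound versus memory-bound split follows from arithmetic intensity: every decompressed K/V tile is reused by all BLOCK_M query rows of a program. The back-of-envelope sketch below assumes each token stores, per KV head, head_dim/2 bytes of packed 4-bit K indices plus a 4-byte K norm, and the same for V (the V-norm size is an assumption); Q and output traffic are ignored.

```python
# Rough FLOP-per-byte estimate for one program of the fused attention kernel.
# Assumptions: per token and KV head, K and V each occupy head_dim/2 bytes of
# packed 4-bit indices plus a 4-byte norm; Q/output traffic is ignored.


def kv_bytes_per_token(head_dim: int) -> int:
    half_d = head_dim // 2
    return (half_d + 4) * 2  # K and V: packed indices + norm each


def flops_per_kv_token(block_m: int, head_dim: int) -> int:
    # Q@K^T and P@V each cost 2 * head_dim FLOPs per (query, key) pair.
    return block_m * 4 * head_dim


def arithmetic_intensity(block_m: int, head_dim: int = 128) -> float:
    return flops_per_kv_token(block_m, head_dim) / kv_bytes_per_token(head_dim)


for block_m in (1, 64):
    print(f"BLOCK_M={block_m:>2}: {arithmetic_intensity(block_m):.1f} FLOP/byte")
```

Intensity grows linearly with BLOCK_M, so the BLOCK_M=64 prefill tile does roughly 64x more math per loaded byte than the BLOCK_M=1 decode path; only in the former does faster INT8 math pay off.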

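The kernel operates in rotated space: the caller pre-rotates Q by Pi^T and post-rotates the output by Pi (the fused_paged_tq4_int8_prefill wrapper does both). QJL correction is deferred: its arguments are placeholders in the kernel signature and are compiled out via constexpr (QJL_DIM=0).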
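Working in rotated space is exact because Pi is orthogonal: (Q Pi^T)(K Pi^T)^T = Q Pi^T Pi K^T = Q K^T, so the softmax weights P are unchanged, and the rotated output P V Pi^T maps back via (P V Pi^T) Pi = P V. The pure-Python check below assumes the cache holds K and V rows already rotated as k Pi^T (consistent with how the wrapper treats Q); the 4x4 block-rotation matrix is a toy stand-in for the real Pi.

```python
import math


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(r) for r in zip(*a)]


def softmax_rows(s):
    out = []
    for row in s:
        m = max(row)  # subtract the row max for numerical stability
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out


def attention(q, k, v, scale):
    scores = [[x * scale for x in row] for row in matmul(q, transpose(k))]
    return matmul(softmax_rows(scores), v)


# Toy orthogonal Pi: two 2-D rotations on the diagonal.
theta = 0.7
c, s = math.cos(theta), math.sin(theta)
Pi = [[c, -s, 0, 0], [s, c, 0, 0], [0, 0, c, s], [0, 0, -s, c]]
PiT = transpose(Pi)

Q = [[0.2, -1.0, 0.5, 0.3], [1.1, 0.4, -0.2, 0.0]]
K = [[0.3, 0.7, -0.1, 0.9], [-0.5, 0.2, 0.8, -0.4], [1.0, 0.0, 0.3, 0.2]]
V = [[1.0, 2.0, 0.0, -1.0], [0.5, -0.5, 1.5, 0.0], [0.0, 1.0, -1.0, 2.0]]
scale = 1 / math.sqrt(4)

# Row vectors rotate as x @ Pi^T (pre-rotation); the output rotates back with @ Pi.
out_rot = attention(matmul(Q, PiT), matmul(K, PiT), matmul(V, PiT), scale)
out = matmul(out_rot, Pi)
ref = attention(Q, K, V, scale)
max_err = max(abs(a - b) for ra, rb in zip(out, ref) for a, b in zip(ra, rb))
print(f"max abs difference: {max_err:.2e}")
```

The remaining difference is floating-point round-off, which is why the wrapper only needs one matmul before and one after the kernel.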

Autotune: 4 configs (BLOCK_N in {16, 32} x num_stages=1 x num_warps in {4, 8}). Stages 2 and 3 were dropped after Experiment 021 profiling showed them 3-5x slower than single-stage at 1K-2K-token prefill on an RTX 4090.
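The autotune search space is small enough to enumerate by hand. The sketch below expresses it as plain dicts; in the module these would presumably be triton.Config objects passed to @triton.autotune, but the exact Config list and autotune key used there are not shown on this page.

```python
import itertools

# Plain-dict stand-in for the kernel's autotune configs (assumption: the real
# module builds equivalent triton.Config objects).
BLOCK_N_CHOICES = (16, 32)
NUM_WARPS_CHOICES = (4, 8)
NUM_STAGES = 1  # stages 2 and 3 were dropped after Experiment 021

AUTOTUNE_CONFIGS = [
    {"BLOCK_N": block_n, "num_warps": warps, "num_stages": NUM_STAGES}
    for block_n, warps in itertools.product(BLOCK_N_CHOICES, NUM_WARPS_CHOICES)
]

for cfg in AUTOTUNE_CONFIGS:
    print(cfg)
```

In Triton terms, each dict corresponds to triton.Config({"BLOCK_N": ...}, num_warps=..., num_stages=1); keeping the grid to 4 entries also keeps first-call autotuning cheap.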

Attributes:

fused_paged_tq4_int8_prefill (callable, returns Tensor): Python wrapper that pre-rotates Q, launches the INT8 prefill kernel, and post-rotates the output.

Examples:

from turboquant_vllm.triton.fused_paged_tq4_int8_prefill import (
    fused_paged_tq4_int8_prefill,
)

out = fused_paged_tq4_int8_prefill(
    q,
    kv_cache,
    block_table,
    seq_lens,
    centroids,
    rotation,
    num_kv_heads=4,
    head_dim=128,
    block_size=16,
)

See Also:

turboquant_vllm.triton.fused_paged_tq4_attention: FP16 decode kernel (BLOCK_M=1).
turboquant_vllm.triton.tq4_decompress: Standalone TQ4 decompression.

Functions

fused_paged_tq4_int8_prefill

fused_paged_tq4_int8_prefill(
    q: Tensor,
    kv_cache: Tensor,
    block_table: Tensor,
    seq_lens: Tensor,
    centroids: Tensor,
    rotation: Tensor,
    num_kv_heads: int,
    head_dim: int,
    block_size: int,
    sm_scale: float | None = None,
    out: Tensor | None = None,
) -> Tensor

Fused paged TQ4 INT8 prefill attention.

Pre-rotates Q by rotation^T, launches the INT8 prefill kernel that decompresses TQ4 blocks in-tile and uses IMMA tensor cores for Q@K^T, then post-rotates the output to return to original space.

Designed for prefill (multiple queries per sequence, compute-bound). For decode (single query, memory-bound), use fused_paged_tq4_decode instead.

Parameters:

q (Tensor, required): Query [num_tokens, H_Q, head_dim], fp16/bf16.

kv_cache (Tensor, required): Packed paged cache [num_blocks, block_size, total_bytes], uint8.

block_table (Tensor, required): Page table [1, max_num_blocks_per_seq], int32. Must contain exactly one sequence (single-sequence kernel).

seq_lens (Tensor, required): Sequence lengths [1], int32. Must contain exactly one entry (single-sequence kernel).

centroids (Tensor, required): TQ4 codebook [16], fp32.

rotation (Tensor, required): Orthogonal rotation [head_dim, head_dim], fp32.

num_kv_heads (int, required): Number of KV heads. Must divide H_Q.

head_dim (int, required): Head dimension (e.g. 128). Must equal q.shape[2].

block_size (int, required): vLLM page size (tokens per block).

sm_scale (float | None, default None): Softmax scale. Defaults to 1 / sqrt(head_dim).

out (Tensor | None, default None): Optional pre-allocated output [num_tokens, H_Q, head_dim].
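The wrapper only asserts H_Q % num_kv_heads == 0; how query heads share KV heads inside the kernel is not spelled out on this page. The sketch below shows the conventional grouped-query mapping (consecutive groups of H_Q // num_kv_heads query heads per KV head), which is an assumption about the kernel, together with the documented sm_scale default. Both helper names are hypothetical.

```python
import math


def kv_head_for_q_head(q_head: int, h_q: int, h_kv: int) -> int:
    """Hypothetical helper: conventional GQA mapping of a query head to its KV head."""
    if h_q % h_kv != 0:
        raise ValueError("H_Q must be divisible by num_kv_heads")
    group = h_q // h_kv  # query heads per KV head
    return q_head // group


def default_sm_scale(head_dim: int) -> float:
    """Mirrors the documented default: 1 / sqrt(head_dim)."""
    return 1.0 / math.sqrt(head_dim)


# e.g. 32 query heads sharing 4 KV heads: groups of 8 consecutive query heads.
print([kv_head_for_q_head(h, 32, 4) for h in range(32)])
print(default_sm_scale(128))
```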

Returns:

Tensor: Attention output [num_tokens, H_Q, head_dim] in original space.

Raises:

ValueError: If seq_lens or block_table contain more than one sequence (the kernel hardcodes seq_id=0).
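Because the kernel is single-sequence, a batched caller has to dispatch one call per sequence. The sketch below is hypothetical: it assumes the batch's query rows are flattened and contiguous per sequence, described by a cumulative-offset list named query_start_loc here for illustration. Slicing block_table[i : i + 1] and seq_lens[i : i + 1] keeps the leading dimension at 1, which passes the ValueError checks.

```python
def per_sequence_slices(query_start_loc):
    """Hypothetical helper: (seq_idx, q_start, q_end) for each sequence in a flattened batch.

    Assumes query_start_loc has num_seqs + 1 non-decreasing entries starting at 0,
    so sequence i owns query rows [query_start_loc[i], query_start_loc[i + 1]).
    """
    if len(query_start_loc) < 2 or query_start_loc[0] != 0:
        raise ValueError("query_start_loc must start at 0 and describe at least one sequence")
    slices = []
    for i in range(len(query_start_loc) - 1):
        start, end = query_start_loc[i], query_start_loc[i + 1]
        if end < start:
            raise ValueError("query_start_loc must be non-decreasing")
        slices.append((i, start, end))
    return slices


# Caller loop sketch (not executed here; torch tensors assumed):
# for i, start, end in per_sequence_slices(query_start_loc):
#     out[start:end] = fused_paged_tq4_int8_prefill(
#         q[start:end], kv_cache, block_table[i : i + 1], seq_lens[i : i + 1],
#         centroids, rotation, num_kv_heads=..., head_dim=..., block_size=...,
#     )
print(per_sequence_slices([0, 5, 12]))
```

This serializes the sequences into separate kernel launches, so it is a correctness workaround rather than a batched fast path.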

Source code in src/turboquant_vllm/triton/fused_paged_tq4_int8_prefill.py
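# Module-level dependencies (defined outside this excerpt): math, torch, triton,
# and the Triton kernel _fused_paged_tq4_int8_prefill_kernel.
def fused_paged_tq4_int8_prefill(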
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    seq_lens: torch.Tensor,
    centroids: torch.Tensor,
    rotation: torch.Tensor,
    num_kv_heads: int,
    head_dim: int,
    block_size: int,
    sm_scale: float | None = None,
    out: torch.Tensor | None = None,
) -> torch.Tensor:
    """Fused paged TQ4 INT8 prefill attention.

    Pre-rotates Q by ``rotation^T``, launches the INT8 prefill kernel
    that decompresses TQ4 blocks in-tile and uses IMMA tensor cores for
    Q@K^T, then post-rotates the output to return to original space.

    Designed for prefill (multiple queries per sequence, compute-bound).
    For decode (single query, memory-bound), use
    :func:`fused_paged_tq4_decode` instead.

    Args:
        q: Query ``[num_tokens, H_Q, head_dim]`` fp16/bf16.
        kv_cache: Packed paged cache ``[num_blocks, block_size, total_bytes]``
            uint8.
        block_table: Page table ``[1, max_num_blocks_per_seq]`` int32.
            Must have exactly one sequence (single-sequence kernel).
        seq_lens: Sequence lengths ``[1]`` int32.  Must have exactly one
            entry (single-sequence kernel).
        centroids: TQ4 codebook ``[16]`` fp32.
        rotation: Orthogonal rotation ``[head_dim, head_dim]`` fp32.
        num_kv_heads: Number of KV heads.
        head_dim: Head dimension (e.g. 128).
        block_size: vLLM page size (tokens per block).
        sm_scale: Softmax scale.  Defaults to ``1 / sqrt(head_dim)``.
        out: Optional pre-allocated output ``[num_tokens, H_Q, head_dim]``.

    Returns:
        Attention output ``[num_tokens, H_Q, head_dim]`` in original space.

    Raises:
        ValueError: If ``seq_lens`` or ``block_table`` contain more than
            one sequence (kernel hardcodes ``seq_id=0``).
    """
    num_tokens, H_Q, D = q.shape

    assert D == head_dim
    assert H_Q % num_kv_heads == 0
    assert kv_cache.dtype == torch.uint8
    assert block_table.dtype == torch.int32
    assert seq_lens.dtype == torch.int32

    # Kernel hardcodes seq_id=0; reject multi-sequence inputs at the API boundary.
    if seq_lens.numel() != 1:
        raise ValueError(
            f"fused_paged_tq4_int8_prefill supports only a single sequence, "
            f"but got seq_lens.numel() == {seq_lens.numel()}."
        )
    if block_table.shape[0] != 1:
        raise ValueError(
            "fused_paged_tq4_int8_prefill supports only a single sequence "
            f"(block_table.shape[0] must be 1), but got {block_table.shape[0]}."
        )

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)

    half_D = head_dim // 2

    # Byte layout constexprs (same as decode kernel)
    k_norm_offset = num_kv_heads * half_D
    v_idx_offset = k_norm_offset + num_kv_heads * 4
    v_norm_offset = v_idx_offset + num_kv_heads * half_D

    # Pre-rotate Q by Pi^T
    q_rot = torch.matmul(q.float(), rotation.T).to(q.dtype)

    out_rot = torch.empty_like(q)

    # INT8 / QJL placeholders (compiled out)
    dummy = torch.empty(0, device=q.device)

    # One program per 64-query tile (BLOCK_M=64) per query head.
    grid = (triton.cdiv(num_tokens, 64), H_Q)

    _fused_paged_tq4_int8_prefill_kernel[grid](
        q_rot,
        kv_cache,
        block_table,
        seq_lens,
        centroids,
        out_rot,
        dummy,  # Q_scale
        dummy,  # QJL_S
        dummy,  # QJL_signs
        dummy,  # QJL_residual_norms
        q_rot.stride(0),
        q_rot.stride(1),
        q_rot.stride(2),
        kv_cache.stride(0),
        kv_cache.stride(1),
        block_table.stride(0),
        block_table.stride(1),
        out_rot.stride(0),
        out_rot.stride(1),
        out_rot.stride(2),
        sm_scale=sm_scale,
        H_Q=H_Q,
        H_KV=num_kv_heads,
        HEAD_DIM=head_dim,
        BLOCK_SIZE=block_size,
        HALF_D=half_D,
        K_NORM_OFFSET=k_norm_offset,
        V_IDX_OFFSET=v_idx_offset,
        V_NORM_OFFSET=v_norm_offset,
        NUM_TOKENS=num_tokens,
        USE_INT8_QK=True,
        QJL_DIM=0,
    )

    # Post-rotate: convert from rotated space back to original space
    result = torch.matmul(out_rot.float(), rotation).to(q.dtype)
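    # Copy into the caller's buffer if one was provided. `result` is still
    # allocated, so passing `out` does not avoid the intermediate tensor.
    if out is not None:
        out.copy_(result)
        return out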
    return result
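The host-side arithmetic in the wrapper can be reproduced without a GPU, which helps when sizing the last dimension of kv_cache. This sketch mirrors the offset computation and launch grid above; the final total_bytes line assumes the per-token slot ends with a 4-byte V norm per head, which the wrapper does not show explicitly.

```python
def tq4_layout(num_kv_heads: int, head_dim: int) -> dict:
    """Per-token byte offsets inside one packed cache slot (mirrors the wrapper)."""
    half_d = head_dim // 2
    # K indices: head_dim 4-bit codes per head, two codes per byte.
    k_norm_offset = num_kv_heads * half_d
    # K norms: 4 bytes per head.
    v_idx_offset = k_norm_offset + num_kv_heads * 4
    # V indices: same packing as K.
    v_norm_offset = v_idx_offset + num_kv_heads * half_d
    # Assumption: V norms are also 4 bytes per head and end the slot.
    total_bytes = v_norm_offset + num_kv_heads * 4
    return {
        "K_NORM_OFFSET": k_norm_offset,
        "V_IDX_OFFSET": v_idx_offset,
        "V_NORM_OFFSET": v_norm_offset,
        "total_bytes": total_bytes,
    }


def prefill_grid(num_tokens: int, h_q: int, block_m: int = 64) -> tuple:
    """Launch grid: one program per BLOCK_M-query tile per query head."""
    return ((num_tokens + block_m - 1) // block_m, h_q)  # integer ceil, like triton.cdiv


print(tq4_layout(num_kv_heads=4, head_dim=128))
print(prefill_grid(num_tokens=1000, h_q=32))  # → (16, 32)
```

With 4 KV heads and head_dim=128, each token occupies 544 bytes under this assumption, compared with 2 x 4 x 128 x 2 = 2048 bytes for FP16 K and V.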