
fused_qk_attention

turboquant_vllm.triton.fused_qk_attention

Fused Triton kernel for TQ4 nibble-packed attention scores.

Computes Q @ compressed_K^T directly from nibble-packed 4-bit indices without materializing decompressed key tensors. Reduces per-layer memory traffic from ~25 MB to ~3 MB during decode.
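The easiest way to see where a saving of this order can come from is to count bytes per cached key vector. The rough accounting below is a sketch under assumptions that are not from the source: head_dim = 128, and an unfused baseline that decompresses keys to fp16, writes them to memory, and reads them back for the matmul. It does not try to reproduce the ~25 MB and ~3 MB figures, which also depend on kv_len, head counts, and the baseline implementation.

```python
def bytes_per_key(head_dim: int) -> dict:
    """Rough bytes touched per cached key vector (one token, one KV head).

    Hypothetical accounting: the unfused path is assumed to decompress keys to
    fp16, write them to memory, then read them back for the Q @ K^T matmul.
    """
    packed = head_dim // 2 + 4  # nibble-packed uint8 indices + one fp32 norm
    fp16_key = 2 * head_dim     # one decompressed fp16 key vector
    return {
        "fused": packed,                   # fused kernel reads the packed form only
        "unfused": packed + 2 * fp16_key,  # read packed, write fp16, re-read fp16
    }


counts = bytes_per_key(128)
print(counts)                                         # {'fused': 68, 'unfused': 580}
print(round(counts["unfused"] / counts["fused"], 1))  # 8.5
```

Multiplying these per-vector figures by kv_len and the number of KV heads gives a per-layer estimate; treat the ratio as an order-of-magnitude guide only.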

Key math:

<q, R^T @ centroids[idx]> = <R @ q, centroids[idx]>

Pre-rotate the query once (q_rot = q @ Pi_T, which is R @ q written for row vectors, with R = Pi); the kernel then computes:

score[s] = norm[s] * sum_d(q_rot[d] * centroids[idx[s,d]]) * scale
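The score formula follows from the identity in two steps. Write $n_s$ for norm[s], $c$ for centroids, and $\hat k_s = n_s R^\top c_{\mathrm{idx}[s]}$ for the reconstructed key implied by the identity and the norm factor, where $c_{\mathrm{idx}[s]}$ is the vector whose $d$-th entry is $c[\mathrm{idx}[s,d]]$. In column-vector form the pre-rotation q_rot = q @ Pi_T is $q_{\mathrm{rot}} = R q$ with $R = \Pi$.

```latex
\begin{aligned}
\mathrm{score}[s]
  &= \mathrm{scale}\cdot\langle q,\ \hat k_s\rangle
   = \mathrm{scale}\cdot n_s\,\langle q,\ R^\top c_{\mathrm{idx}[s]}\rangle \\
  &= \mathrm{scale}\cdot n_s\,\langle R q,\ c_{\mathrm{idx}[s]}\rangle
   = \mathrm{scale}\cdot n_s \sum_{d} q_{\mathrm{rot}}[d]\; c\big[\mathrm{idx}[s,d]\big]
\end{aligned}
```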

Based on the Dejan.ai TurboQuant Triton kernel, adapted for:

  • Nibble-packed 4-bit indices (two per uint8 byte)
  • fp32 norms (fp16 causes precision loss at 10K+ tokens)
  • Configurable GQA (tested with 4:1 ratio for Molmo2)
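Nibble packing stores two 4-bit indices in each uint8, halving index storage compared with one byte per index. A minimal PyTorch sketch follows. The order within a byte (element 2j in the low nibble, element 2j+1 in the high nibble) is an assumption, and pack_nibbles / unpack_nibbles are hypothetical helpers, not turboquant_vllm functions.

```python
import torch


def pack_nibbles(idx: torch.Tensor) -> torch.Tensor:
    """Pack 4-bit indices (values 0..15) along the last dim, two per uint8.

    Hypothetical layout: element 2j goes in the low nibble and element 2j+1
    in the high nibble of byte j. The real cache layout may differ.
    """
    assert idx.shape[-1] % 2 == 0, "last dim must be even"
    idx = idx.to(torch.uint8)
    lo = idx[..., 0::2]      # even positions -> low nibble
    hi = idx[..., 1::2]      # odd positions -> high nibble
    return lo | (hi << 4)    # [..., D // 2] uint8


def unpack_nibbles(packed: torch.Tensor) -> torch.Tensor:
    """Inverse of pack_nibbles: uint8 [..., D // 2] -> int64 indices [..., D]."""
    lo = (packed & 0x0F).long()
    hi = (packed >> 4).long()
    # Interleave back: stack on a new last axis, then merge the last two dims.
    return torch.stack((lo, hi), dim=-1).flatten(-2)


idx = torch.randint(0, 16, (2, 8, 5, 128))  # [B, n_kv_heads, kv_len, head_dim]
packed = pack_nibbles(idx)                  # [2, 8, 5, 64] uint8
```

The unpacked tensor can be used directly to gather centroids, e.g. centroids[unpack_nibbles(packed)].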

Attributes:

| Name | Type | Description |
|---|---|---|
| fused_qk_scores | function (returns Tensor) | Python wrapper that launches the Triton kernel. |

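Examples:

from turboquant_vllm.triton.fused_qk_attention import fused_qk_scores

scores = fused_qk_scores(
    q_rotated,  # [B, n_q_heads, q_len, head_dim], already rotated by Pi_T
    packed_indices,  # [B, n_kv_heads, kv_len, head_dim // 2] uint8
    norms,  # [B, n_kv_heads, kv_len] fp32
    centroids,  # [16] fp32 (for 4-bit)
    scale=1 / 128**0.5,  # 1 / sqrt(head_dim) for head_dim = 128
    n_q_heads=32,
    n_kv_heads=8,
)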

See Also

  • turboquant_vllm.kv_cache: CompressedDynamicCache, which produces the nibble-packed indices and fp32 norms consumed by this kernel.
  • Dejan.ai TurboQuant blog (https://dejan.ai/blog/turboquant/): the original Triton kernel reference.

Functions

fused_qk_scores

fused_qk_scores(
    q_rotated: Tensor,
    packed_indices: Tensor,
    norms: Tensor,
    centroids: Tensor,
    scale: float,
    *,
    n_q_heads: int,
    n_kv_heads: int,
) -> Tensor

Compute attention scores from pre-rotated queries and nibble-packed keys.

The query must be pre-rotated by the TurboQuant rotation matrix (q_rot = q @ Pi_T) so the kernel avoids the expensive 128x128 rotation matmul in its inner loop.
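Pre-rotation works because rotating the query by Pi gives the same dot product as un-rotating each key by Pi^T. A quick numeric check follows, with a random orthogonal matrix from torch.linalg.qr standing in for the TurboQuant rotation (an assumption; the real Pi comes from the cache). Pre-rotating costs one head_dim x head_dim product per query vector instead of one per cached key.

```python
import torch

torch.manual_seed(0)
D = 128
# Hypothetical stand-in for the TurboQuant rotation Pi: a random orthogonal matrix.
Pi, _ = torch.linalg.qr(torch.randn(D, D, dtype=torch.float64))

q = torch.randn(D, dtype=torch.float64)      # one query vector (row form)
k_rot = torch.randn(D, dtype=torch.float64)  # one key in the rotated basis

# Option 1: un-rotate the key (k = Pi^T k_rot, written k_rot @ Pi for rows),
# which would cost a D x D product for every cached key.
score_unrotated_key = q @ (k_rot @ Pi)

# Option 2: rotate the query once (q_rot = q @ Pi_T) and dot it with k_rot.
q_rot = q @ Pi.T
score_pre_rotated = q_rot @ k_rot
```

Both options produce the same scalar; the kernel only needs Option 2's cheap inner product per key.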

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| q_rotated | Tensor | Pre-rotated query, shape (batch, n_q_heads, q_len, head_dim). | required |
| packed_indices | Tensor | Nibble-packed 4-bit key indices, shape (batch, n_kv_heads, kv_len, head_dim // 2), dtype uint8. | required |
| norms | Tensor | Key vector norms, shape (batch, n_kv_heads, kv_len), dtype float32. | required |
| centroids | Tensor | Lloyd-Max centroid values, shape (n_levels,), dtype float32. | required |
| scale | float | Attention scale factor (typically 1 / sqrt(head_dim)). | required |
| n_q_heads | int | Number of query attention heads. | required |
| n_kv_heads | int | Number of key-value attention heads. | required |
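To illustrate what Lloyd-Max centroid values are: a 1-D Lloyd-Max quantizer alternates nearest-centroid assignment with moving each centroid to the mean of its cell, converging to a local minimum of the mean squared quantization error. The sketch below fits 16 levels (4 bits) to samples. The sample distribution, N(0, 1/head_dim), is an assumption standing in for rotated, unit-normalized key coordinates, and lloyd_max_centroids is a hypothetical helper; how turboquant_vllm actually derives its centroids is not shown on this page.

```python
import torch


def lloyd_max_centroids(samples: torch.Tensor, n_levels: int = 16, iters: int = 50) -> torch.Tensor:
    """Fit a 1-D Lloyd-Max scalar quantizer to samples (hypothetical helper)."""
    x = samples.flatten().float()
    # Start at evenly spaced quantiles so every level begins with some samples.
    qs = (torch.arange(n_levels, dtype=torch.float32) + 0.5) / n_levels
    c = torch.quantile(x, qs)
    for _ in range(iters):
        # Assignment step: nearest centroid for every sample.
        assign = torch.argmin((x[:, None] - c[None, :]).abs(), dim=1)
        # Update step: move each centroid to the mean of its cell.
        for j in range(n_levels):
            members = x[assign == j]
            if members.numel() > 0:
                c[j] = members.mean()
    return torch.sort(c).values


torch.manual_seed(0)
head_dim = 128
# Assumed stand-in for rotated, unit-normalized key coordinates.
samples = torch.randn(20_000) / head_dim**0.5
centroids = lloyd_max_centroids(samples)  # shape (16,), float32
```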

Returns:

| Type | Description |
|---|---|
| Tensor | Attention scores, shape (batch, n_q_heads, q_len, kv_len), dtype float32. |
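When validating the kernel, an eager PyTorch reference with the same signature is handy. The sketch below is hypothetical (reference_qk_scores is not part of turboquant_vllm). It assumes low-nibble-first packing and the common GQA convention that query head h reads KV head h // (n_q_heads // n_kv_heads), since the kernel source is not shown here. It deliberately materializes the decompressed keys that the Triton kernel avoids, so it suits correctness checks only.

```python
import torch


def reference_qk_scores(
    q_rotated: torch.Tensor,       # [B, n_q_heads, q_len, D]
    packed_indices: torch.Tensor,  # [B, n_kv_heads, kv_len, D // 2] uint8
    norms: torch.Tensor,           # [B, n_kv_heads, kv_len] fp32
    centroids: torch.Tensor,       # [n_levels] fp32
    scale: float,
    *,
    n_q_heads: int,
    n_kv_heads: int,
) -> torch.Tensor:
    """Hypothetical eager reference; assumes low-nibble-first packing."""
    lo = (packed_indices & 0x0F).long()
    hi = (packed_indices >> 4).long()
    idx = torch.stack((lo, hi), dim=-1).flatten(-2)            # [B, n_kv, kv_len, D]
    k_rot = centroids.float()[idx] * norms.float()[..., None]  # rotated, rescaled keys
    # Assumed GQA rule: each KV head serves n_q_heads // n_kv_heads consecutive query heads.
    k_rot = k_rot.repeat_interleave(n_q_heads // n_kv_heads, dim=1)  # [B, n_q, kv_len, D]
    return torch.matmul(q_rotated.float(), k_rot.transpose(-1, -2)) * scale


torch.manual_seed(0)
B, n_q, n_kv, q_len, kv_len, D = 1, 8, 2, 3, 10, 128
q_rot = torch.randn(B, n_q, q_len, D)
packed = torch.randint(0, 256, (B, n_kv, kv_len, D // 2), dtype=torch.uint8)
norms = torch.rand(B, n_kv, kv_len)
centroids = torch.linspace(-1.0, 1.0, 16)
scores = reference_qk_scores(
    q_rot, packed, norms, centroids, 1 / D**0.5, n_q_heads=n_q, n_kv_heads=n_kv
)  # [1, 8, 3, 10]
```

On a GPU, the same inputs moved to CUDA can be passed to fused_qk_scores and compared with torch.allclose, provided the packing assumption matches the cache.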

Source code in src/turboquant_vllm/triton/fused_qk_attention.py
def fused_qk_scores(
    q_rotated: torch.Tensor,
    packed_indices: torch.Tensor,
    norms: torch.Tensor,
    centroids: torch.Tensor,
    scale: float,
    *,
    n_q_heads: int,
    n_kv_heads: int,
) -> torch.Tensor:
    """Compute attention scores from pre-rotated queries and nibble-packed keys.

    The query must be pre-rotated by the TurboQuant rotation matrix
    (``q_rot = q @ Pi_T``) so the kernel avoids the expensive 128x128
    rotation matmul in its inner loop.

    Args:
        q_rotated: Pre-rotated query, shape
            ``(batch, n_q_heads, q_len, head_dim)``.
        packed_indices: Nibble-packed 4-bit key indices, shape
            ``(batch, n_kv_heads, kv_len, head_dim // 2)``, dtype uint8.
        norms: Key vector norms, shape
            ``(batch, n_kv_heads, kv_len)``, dtype float32.
        centroids: Lloyd-Max centroid values, shape ``(n_levels,)``,
            dtype float32.
        scale: Attention scale factor (typically ``1 / sqrt(head_dim)``).
        n_q_heads: Number of query attention heads.
        n_kv_heads: Number of key-value attention heads.

    Returns:
        Attention scores, shape ``(batch, n_q_heads, q_len, kv_len)``.
    """
    batch, _, q_len, head_dim = q_rotated.shape
    _, _, kv_len, _ = packed_indices.shape

    ki_flat = packed_indices.reshape(
        batch * n_kv_heads, kv_len, head_dim // 2
    ).contiguous()
    kn_flat = norms.reshape(batch * n_kv_heads, kv_len).contiguous()
    centroids = centroids.contiguous().float()

    # For q_len > 1 (prefill), process each query position separately
    # to keep the GQA head mapping correct. The kernel maps
    # q_head → kv_head via gqa_ratio = n_q_heads // n_kv_heads.
    # Flattening q_len into heads would break this mapping.
    results = []
    for q_pos in range(q_len):
        q_slice = (
            q_rotated[:, :, q_pos : q_pos + 1, :]
            .reshape(batch * n_q_heads, head_dim)
            .contiguous()
        )

        out = torch.empty(
            batch * n_q_heads,
            kv_len,
            device=q_rotated.device,
            dtype=torch.float32,
        )

        grid = (batch * n_q_heads, triton.cdiv(kv_len, 64))

        _fused_qk_nibble_kernel[grid](
            q_slice,
            ki_flat,
            kn_flat,
            centroids,
            out,
            kv_len,
            head_dim,
            n_q_heads,
            n_kv_heads,
            scale,
            q_slice.stride(0),
            q_slice.stride(1),
            ki_flat.stride(0),
            ki_flat.stride(1),
            ki_flat.stride(2),
            kn_flat.stride(0),
            kn_flat.stride(1),
            out.stride(0),
            out.stride(1),
        )

        results.append(out.reshape(batch, n_q_heads, 1, kv_len))

    return torch.cat(results, dim=2)
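The comment inside the loop says that flattening q_len into the head axis would break the GQA mapping. The sketch below makes that concrete under an assumed kernel rule: flattened query row r reads flattened KV row r // gqa_ratio. The kernel source is not shown on this page, so kv_row is hypothetical. One launch per position keeps every query row on the right KV row; folding q_len into the rows sends most rows to the wrong head or batch element, and some past the end of the KV tensor.

```python
def kv_row(q_row: int, n_q_heads: int, n_kv_heads: int) -> int:
    """Assumed kernel rule: flattened query row -> flattened KV row.

    With q_row = b * n_q_heads + h this equals b * n_kv_heads + h // gqa_ratio.
    """
    gqa_ratio = n_q_heads // n_kv_heads
    return q_row // gqa_ratio


batch, n_q, n_kv, q_len = 2, 4, 2, 2
ratio = n_q // n_kv

expected, per_position, flattened = {}, {}, {}
for b in range(batch):
    for h in range(n_q):
        for pos in range(q_len):
            key = (b, h, pos)
            expected[key] = b * n_kv + h // ratio
            # What fused_qk_scores does: one launch per position, rows b * n_q + h.
            per_position[key] = kv_row(b * n_q + h, n_q, n_kv)
            # Hypothetical alternative: fold q_len into the row axis.
            flattened[key] = kv_row((b * n_q + h) * q_len + pos, n_q, n_kv)

mismatches = [k for k in expected if flattened[k] != expected[k]]
print(len(mismatches), max(flattened.values()))  # 14 7  (valid KV rows are 0..3)
```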