
flash_attention_tq4

turboquant_vllm.triton.flash_attention_tq4

Fused TQ4 Flash Attention -- K decompression inside the FA inner loop.

Phase 2 of the P5 roadmap. Replaces the standard K tile load with: nibble unpack -> centroid gather -> interleave -> norm scale. The query is pre-rotated by Pi^T outside the kernel. Values remain standard fp16/bf16.
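The decompression pipeline the kernel fuses can be sketched outside Triton with NumPy. This is an illustrative reference only; the nibble layout (low nibble holds the even dimension's index) is an assumption, not necessarily the kernel's actual packing order:

```python
import numpy as np

def dequant_key(packed: np.ndarray, centroids: np.ndarray, norm: float) -> np.ndarray:
    """Reference decompression of one nibble-packed key vector:
    nibble unpack -> centroid gather -> interleave -> norm scale.
    Low-nibble = even dim is an assumed layout for illustration."""
    lo = packed & 0x0F            # 4-bit indices for even dims (assumed)
    hi = packed >> 4              # 4-bit indices for odd dims (assumed)
    idx = np.empty(packed.size * 2, dtype=np.uint8)
    idx[0::2] = lo                # interleave halves back to head_dim order
    idx[1::2] = hi
    return centroids[idx] * norm  # gather fp32 centroids, rescale by key norm
```

The fused kernel performs the same gather per K tile inside the inner loop instead of materializing a dequantized K tensor in global memory.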

The fp32 online softmax (m_i, l_i, acc) state machine prevents the 0.023/layer cosine drift that killed the Q@K^T-only kernel (Key Lesson #7).
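The (m_i, l_i, acc) recurrence is the standard flash-attention online softmax; a minimal NumPy sketch of one query row processed block-by-block (generic illustration, not the kernel itself):

```python
import numpy as np

def online_softmax_attention(scores: np.ndarray, v: np.ndarray, block: int = 4) -> np.ndarray:
    """One query row reduced over K/V blocks, keeping the running
    max (m_i), normalizer (l_i) and accumulator (acc) in fp32."""
    m = -np.inf
    l = 0.0
    acc = np.zeros(v.shape[-1], dtype=np.float32)
    for start in range(0, scores.shape[0], block):
        s = scores[start:start + block].astype(np.float32)
        m_new = max(m, float(s.max()))
        alpha = np.exp(m - m_new)   # rescale previously accumulated state
        p = np.exp(s - m_new)       # unnormalized probs for this block
        l = l * alpha + float(p.sum())
        acc = acc * alpha + p @ v[start:start + block].astype(np.float32)
        m = m_new
    return acc / l
```

Keeping all three state variables in fp32 is what bounds the per-layer error accumulation; the result is exactly equal (up to rounding) to a full softmax over the whole row.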

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `triton_flash_attention_tq4` | `Tensor` | Python wrapper that pre-rotates Q and launches the fused TQ4 kernel. |

Examples:

from turboquant_vllm.triton.flash_attention_tq4 import (
    triton_flash_attention_tq4,
)

out = triton_flash_attention_tq4(
    q,
    k_packed,
    k_norms,
    centroids,
    rotation_matrix,
    v,
)
See Also

:mod:`turboquant_vllm.triton.flash_attention`: Phase 1 vanilla kernel.
:mod:`turboquant_vllm.quantizer`: TurboQuantMSE rotation + quantization.

Functions

triton_flash_attention_tq4

triton_flash_attention_tq4(
    q: Tensor,
    k_packed: Tensor,
    k_norms: Tensor,
    centroids: Tensor,
    rotation: Tensor,
    v: Tensor,
    sm_scale: float | None = None,
    is_causal: bool = False,
) -> Tensor

Fused TQ4 Flash Attention with compressed K and standard V.

Pre-rotates Q by rotation^T, then launches the fused kernel that decompresses nibble-packed K indices inline via centroid gather.
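Pre-rotating Q works because the rotation is orthogonal: if the keys were rotated by the same matrix before quantization, the score matrix Q@K^T is unchanged (up to quantization error). A quick NumPy check of this invariance, using a random QR-derived orthogonal matrix as a stand-in for the learned rotation:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
# Stand-in orthogonal rotation (the real one comes from the quantizer).
pi, _ = np.linalg.qr(rng.standard_normal((d, d)))
q = rng.standard_normal((4, d))
k = rng.standard_normal((6, d))
# Rotating both sides by pi.T leaves the score matrix unchanged.
scores_rot = (q @ pi.T) @ (k @ pi.T).T
assert np.allclose(scores_rot, q @ k.T)
```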

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `q` | `Tensor` | Query `[batch, H_Q, seq_q, head_dim]` fp16/bf16. | *required* |
| `k_packed` | `Tensor` | Nibble-packed key indices `[batch, H_KV, seq_kv, head_dim//2]` uint8. | *required* |
| `k_norms` | `Tensor` | Key norms `[batch, H_KV, seq_kv]` or `[..., 1]` fp32. | *required* |
| `centroids` | `Tensor` | Lloyd-Max codebook `[16]` fp32 (for 4-bit). | *required* |
| `rotation` | `Tensor` | Orthogonal rotation matrix `[head_dim, head_dim]` fp32. | *required* |
| `v` | `Tensor` | Values `[batch, H_KV, seq_kv, head_dim]` fp16/bf16. | *required* |
| `sm_scale` | `float \| None` | Softmax scale. Defaults to `1 / sqrt(head_dim)`. | `None` |
| `is_causal` | `bool` | Apply causal masking. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Attention output `[batch, H_Q, seq_q, head_dim]`. |

Source code in src/turboquant_vllm/triton/flash_attention_tq4.py
def triton_flash_attention_tq4(
    q: torch.Tensor,
    k_packed: torch.Tensor,
    k_norms: torch.Tensor,
    centroids: torch.Tensor,
    rotation: torch.Tensor,
    v: torch.Tensor,
    sm_scale: float | None = None,
    is_causal: bool = False,
) -> torch.Tensor:
    """Fused TQ4 Flash Attention with compressed K and standard V.

    Pre-rotates Q by ``rotation^T``, then launches the fused kernel that
    decompresses nibble-packed K indices inline via centroid gather.

    Args:
        q: Query ``[batch, H_Q, seq_q, head_dim]`` fp16/bf16.
        k_packed: Nibble-packed key indices ``[batch, H_KV, seq_kv, head_dim//2]`` uint8.
        k_norms: Key norms ``[batch, H_KV, seq_kv]`` or ``[..., 1]`` fp32.
        centroids: Lloyd-Max codebook ``[16]`` fp32 (for 4-bit).
        rotation: Orthogonal rotation matrix ``[head_dim, head_dim]`` fp32.
        v: Values ``[batch, H_KV, seq_kv, head_dim]`` fp16/bf16.
        sm_scale: Softmax scale. Defaults to ``1 / sqrt(head_dim)``.
        is_causal: Apply causal masking.

    Returns:
        Attention output ``[batch, H_Q, seq_q, head_dim]``.
    """
    B, H_Q, N_Q, D = q.shape
    _, H_KV, N_KV, HALF_D = k_packed.shape

    assert HALF_D == D // 2, f"Packed dim {HALF_D} != head_dim//2 ({D // 2})"
    assert H_Q % H_KV == 0, f"Q heads ({H_Q}) must be divisible by KV heads ({H_KV})"
    assert k_packed.dtype == torch.uint8, "k_packed must be uint8"
    assert k_norms.dtype == torch.float32, "k_norms must be float32"
    assert centroids.dtype == torch.float32, "centroids must be float32"

    # Squeeze trailing 1 from norms if present
    if k_norms.dim() == 4 and k_norms.shape[-1] == 1:
        k_norms = k_norms.squeeze(-1)

    if is_causal and N_Q == 1:
        is_causal = False

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # Pre-rotate Q: q_rot = q @ Pi^T
    q_rot = torch.matmul(q.float(), rotation.T).to(q.dtype)

    out = torch.empty_like(q)

    def grid(META: dict) -> tuple[int, int]:
        """Compute launch grid from autotuned block size.

        Returns:
            ``(num_q_blocks, batch * H_Q)`` grid dimensions.
        """
        return (triton.cdiv(N_Q, META["BLOCK_M"]), B * H_Q)

    _fwd_tq4_kernel[grid](
        q_rot,
        k_packed,
        k_norms,
        centroids,
        v,
        out,
        sm_scale,
        *q_rot.stride(),
        *k_packed.stride(),
        *k_norms.stride(),
        *v.stride(),
        *out.stride(),
        H_Q,
        H_KV,
        N_Q,
        N_KV,
        HEAD_DIM=D,
        IS_CAUSAL=is_causal,
    )

    return out
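For intuition, the `grid` closure above launches one program per (query block, batch x query head) pair, with the query axis sized by ceiling division against the autotuned `BLOCK_M`. A minimal sketch of that sizing (plain Python, standing in for `triton.cdiv`):

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, as used to size the query-block grid axis."""
    return (a + b - 1) // b

def launch_grid(n_q: int, batch: int, h_q: int, block_m: int) -> tuple[int, int]:
    # Axis 0: number of query blocks; axis 1: one program per (batch, query head).
    return (cdiv(n_q, block_m), batch * h_q)
```

Note that the second axis uses `H_Q`, not `H_KV`: each query head gets its own program, and the kernel maps it to a shared KV head for grouped-query attention (hence the `H_Q % H_KV == 0` assertion in the wrapper).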