
flash_attention_tq4_kv

turboquant_vllm.triton.flash_attention_tq4_kv

Fused TQ4 Flash Attention -- both K and V decompressed in the inner loop.

Phase 3 of the P5 roadmap. Both K and V tiles are decompressed inline from nibble-packed uint8 indices. The query is pre-rotated by Pi^T and the output is post-rotated by Pi outside the kernel (since K and V share the same rotation matrix).
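Why the rotation can be moved outside the kernel: orthogonal rotations preserve dot products, so attention scores computed entirely in the rotated space match the original-space scores, and a single post-rotation returns the output to the original coordinates. A minimal numpy sketch (a stand-in illustration, not the library's code):

```python
import numpy as np

rng = np.random.default_rng(0)
D = 8
q = rng.standard_normal(D)
k = rng.standard_normal(D)
v = rng.standard_normal(D)

# Random orthogonal matrix standing in for the shared rotation Pi.
Pi, _ = np.linalg.qr(rng.standard_normal((D, D)))

# K is stored rotated; pre-rotating Q by Pi^T leaves the score unchanged,
# since (q Pi^T) . (k Pi^T) = q Pi^T Pi k = q . k.
score = q @ k
score_rot = (q @ Pi.T) @ (k @ Pi.T)

# The weighted sum of rotated V lives in rotated space; multiplying by Pi
# on the way out recovers the original coordinates.
v_back = (v @ Pi.T) @ Pi
```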

Attributes:

    triton_flash_attention_tq4_kv (callable): Python wrapper that pre-rotates Q, launches the fused kernel, and post-rotates the output.

Examples:

from turboquant_vllm.triton.flash_attention_tq4_kv import (
    triton_flash_attention_tq4_kv,
)

out = triton_flash_attention_tq4_kv(
    q,
    k_packed,
    k_norms,
    v_packed,
    v_norms,
    centroids,
    rotation,
    sm_scale=None,
)
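For intuition, here is a hypothetical round trip of the nibble packing and dequantization: two 4-bit codebook indices per uint8 byte (low nibble = even element, high nibble = odd element; the packing order is an assumption, not taken from the kernel source), reconstructed as x_hat = norm * centroids[index]:

```python
import numpy as np

centroids = np.linspace(-1.0, 1.0, 16, dtype=np.float32)  # stand-in codebook

idx = np.array([3, 12, 0, 15], dtype=np.uint8)            # D 4-bit indices
packed = (idx[0::2] | (idx[1::2] << 4)).astype(np.uint8)  # D//2 bytes

# Unpack low/high nibbles back into D indices.
unpacked = np.empty_like(idx)
unpacked[0::2] = packed & 0x0F
unpacked[1::2] = packed >> 4

# Dequantize: per-vector norm times the shared 16-entry codebook lookup.
norm = np.float32(2.5)
x_hat = norm * centroids[unpacked]
```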
See Also

turboquant_vllm.triton.flash_attention_tq4: Phase 2 (K-only decompression).

Functions

triton_flash_attention_tq4_kv

triton_flash_attention_tq4_kv(
    q: Tensor,
    k_packed: Tensor,
    k_norms: Tensor,
    v_packed: Tensor,
    v_norms: Tensor,
    centroids: Tensor,
    rotation: Tensor,
    sm_scale: float | None = None,
    is_causal: bool = False,
) -> Tensor

Fused TQ4 Flash Attention with both K and V compressed.

Pre-rotates Q by rotation^T, launches the kernel that decompresses both K and V inline, then post-rotates the output by rotation to return to the original coordinate space.

Parameters:

    q (Tensor): Query [batch, H_Q, seq_q, head_dim] fp16/bf16. Required.
    k_packed (Tensor): Nibble-packed key indices [batch, H_KV, seq_kv, D//2] uint8. Required.
    k_norms (Tensor): Key norms [batch, H_KV, seq_kv] or [..., 1] fp32. Required.
    v_packed (Tensor): Nibble-packed value indices [batch, H_KV, seq_kv, D//2] uint8. Required.
    v_norms (Tensor): Value norms [batch, H_KV, seq_kv] or [..., 1] fp32. Required.
    centroids (Tensor): Shared Lloyd-Max codebook [16] fp32. Required.
    rotation (Tensor): Shared orthogonal rotation [head_dim, head_dim] fp32. Required.
    sm_scale (float | None): Softmax scale. Defaults to 1 / sqrt(head_dim).
    is_causal (bool): Apply causal masking. Defaults to False.

Returns:

    Tensor: Attention output [batch, H_Q, seq_q, head_dim] in original space.
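The wrapper accepts fewer K/V heads than query heads (it asserts H_Q % H_KV == 0 in the source below), i.e. grouped-query attention: each group of H_Q // H_KV query heads reads one compressed K/V head. A minimal sketch of that head mapping, assuming the usual contiguous grouping:

```python
# Grouped-query head mapping: H_Q query heads share H_KV compressed
# K/V heads in contiguous groups of size H_Q // H_KV (grouping layout
# assumed, not taken from the kernel source).
H_Q, H_KV = 8, 2
group_size = H_Q // H_KV

kv_head_for = [q_head // group_size for q_head in range(H_Q)]
```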

Source code in src/turboquant_vllm/triton/flash_attention_tq4_kv.py
def triton_flash_attention_tq4_kv(
    q: torch.Tensor,
    k_packed: torch.Tensor,
    k_norms: torch.Tensor,
    v_packed: torch.Tensor,
    v_norms: torch.Tensor,
    centroids: torch.Tensor,
    rotation: torch.Tensor,
    sm_scale: float | None = None,
    is_causal: bool = False,
) -> torch.Tensor:
    """Fused TQ4 Flash Attention with both K and V compressed.

    Pre-rotates Q by ``rotation^T``, launches the kernel that decompresses
    both K and V inline, then post-rotates the output by ``rotation`` to
    return to the original coordinate space.

    Args:
        q: Query ``[batch, H_Q, seq_q, head_dim]`` fp16/bf16.
        k_packed: Nibble-packed key indices ``[batch, H_KV, seq_kv, D//2]`` uint8.
        k_norms: Key norms ``[batch, H_KV, seq_kv]`` or ``[..., 1]`` fp32.
        v_packed: Nibble-packed value indices ``[batch, H_KV, seq_kv, D//2]`` uint8.
        v_norms: Value norms ``[batch, H_KV, seq_kv]`` or ``[..., 1]`` fp32.
        centroids: Shared Lloyd-Max codebook ``[16]`` fp32.
        rotation: Shared orthogonal rotation ``[head_dim, head_dim]`` fp32.
        sm_scale: Softmax scale. Defaults to ``1 / sqrt(head_dim)``.
        is_causal: Apply causal masking.

    Returns:
        Attention output ``[batch, H_Q, seq_q, head_dim]`` in original space.
    """
    B, H_Q, N_Q, D = q.shape
    _, H_KV, N_KV, HALF_D = k_packed.shape

    assert HALF_D == D // 2
    assert H_Q % H_KV == 0
    assert k_packed.dtype == torch.uint8
    assert v_packed.dtype == torch.uint8
    assert k_norms.dtype == torch.float32
    assert v_norms.dtype == torch.float32

    # Squeeze trailing 1 from norms
    if k_norms.dim() == 4 and k_norms.shape[-1] == 1:
        k_norms = k_norms.squeeze(-1)
    if v_norms.dim() == 4 and v_norms.shape[-1] == 1:
        v_norms = v_norms.squeeze(-1)

    if is_causal and N_Q == 1:
        is_causal = False

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # Pre-rotate Q
    q_rot = torch.matmul(q.float(), rotation.T).to(q.dtype)

    out_rot = torch.empty_like(q)

    def grid(META: dict) -> tuple[int, int]:
        """Compute launch grid.

        Returns:
            ``(num_q_blocks, batch * H_Q)`` grid dimensions.
        """
        return (triton.cdiv(N_Q, META["BLOCK_M"]), B * H_Q)

    _fwd_tq4_kv_kernel[grid](
        q_rot,
        k_packed,
        k_norms,
        v_packed,
        v_norms,
        centroids,
        out_rot,
        sm_scale,
        *q_rot.stride(),
        *k_packed.stride(),
        *k_norms.stride(),
        *v_packed.stride(),
        *v_norms.stride(),
        *out_rot.stride(),
        H_Q,
        H_KV,
        N_Q,
        N_KV,
        HEAD_DIM=D,
        IS_CAUSAL=is_causal,
    )

    # Post-rotate: convert from rotated space back to original space
    return torch.matmul(out_rot.float(), rotation).to(q.dtype)