
flash_attention

turboquant_vllm.triton.flash_attention

Triton Flash Attention v2 -- forward-only kernel with GQA support.

Phase 1 of the fused TQ4 Flash Attention roadmap (P5). This vanilla kernel reproduces the output of PyTorch's scaled_dot_product_attention (SDPA) within floating-point tolerance. It serves as the scaffold into which Phase 2 injects TQ4 decompression at the K/V tile load points.

Supports
  • Grouped-Query Attention (GQA) with arbitrary Q/KV head ratios
  • Causal and non-causal modes
  • Optional additive attention mask (HF-compatible)
  • fp32 online softmax accumulation for numerical stability
  • RTX 4090 (SM89) and AMD ROCm via Triton HIP backend
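Under GQA, each group of num_q_heads // num_kv_heads query heads shares one K/V head. A minimal pure-Python sketch of that index mapping, assuming the usual contiguous grouping (the same convention as HuggingFace's KV-head expansion); the kernel's exact indexing is not shown here, and kv_head_for is a hypothetical helper, not part of turboquant_vllm:

```python
def kv_head_for(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Map a query head index to the K/V head it reads (hypothetical helper).

    Assumes contiguous grouping: query heads [0, group) share KV head 0,
    [group, 2 * group) share KV head 1, and so on.
    """
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be divisible by num_kv_heads")
    group = num_q_heads // num_kv_heads  # query heads per KV head
    return q_head // group


# 8 query heads sharing 2 KV heads (4:1 GQA ratio)
mapping = [kv_head_for(h, 8, 2) for h in range(8)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1]
```

With num_kv_heads == num_q_heads the mapping is the identity (standard multi-head attention); with num_kv_heads == 1 every query head reads head 0 (multi-query attention).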
Algorithm

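Implements the online softmax from FlashAttention-2 (Dao, 2023). Three fp32 state variables are kept per query row and updated across K/V tile iterations: the running maximum m_i, the running softmax denominator l_i, and the output accumulator acc. When the running maximum increases, the correction factor alpha = exp2(m_old - m_new) rescales the previously accumulated l_i and acc; otherwise alpha = 1. (exp2 can stand in for exp when the scores are pre-multiplied by log2(e).) The rescaling is mathematically exact, not an approximation: the result equals standard softmax attention up to floating-point rounding.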
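A pure-Python sketch of the per-row update described above, under stated assumptions: the scores are already multiplied by sm_scale, slices of block_n keys stand in for the kernel's K/V tiles, and the scores are pre-multiplied by log2(e) so that 2 ** x plays the role of exp2. The function names are illustrative, not part of turboquant_vllm. The tiled result is checked against a direct softmax.

```python
import math

LOG2E = 1.4426950408889634  # log2(e): exp(x) == 2 ** (x * LOG2E)


def online_softmax_attention_row(scores, values, block_n=2):
    """Attention output for one query row, processed tile by tile.

    scores: scaled logits (q . k_j * sm_scale), one per key.
    values: one value vector (list of floats) per key.
    """
    d = len(values[0])
    m_i = -math.inf  # running max, in the base-2 domain
    l_i = 0.0        # running softmax denominator
    acc = [0.0] * d  # running unnormalized output
    for start in range(0, len(scores), block_n):
        s_tile = [s * LOG2E for s in scores[start:start + block_n]]
        v_tile = values[start:start + block_n]
        m_new = max(m_i, max(s_tile))
        alpha = 2.0 ** (m_i - m_new)  # rescales earlier work; 0.0 on the first tile
        p = [2.0 ** (s - m_new) for s in s_tile]
        l_i = l_i * alpha + sum(p)
        acc = [a * alpha for a in acc]
        for p_j, v_j in zip(p, v_tile):
            acc = [a + p_j * x for a, x in zip(acc, v_j)]
        m_i = m_new
    return [a / l_i for a in acc]  # normalize once, after the last tile


def softmax_attention_row(scores, values):
    """Direct (non-tiled) softmax-weighted sum, for comparison."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    d = len(values[0])
    return [sum(w_j * v_j[c] for w_j, v_j in zip(w, values)) / z for c in range(d)]


scores = [0.5, 2.0, -1.0, 3.0, 1.5]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
tiled = online_softmax_attention_row(scores, values, block_n=2)
direct = softmax_attention_row(scores, values)
print(all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12) for a, b in zip(tiled, direct)))  # True
```

The division by l_i happens only once at the end, which is why the kernel can keep acc unnormalized across tiles.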

Attributes:

| Name | Type | Description |
|---|---|---|
| triton_flash_attention | Callable[..., Tensor] | Python wrapper that launches the Triton kernel with autotuned block sizes. |
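The wrapper launches one program (CTA) per (query block, batch × query head) pair, with BLOCK_M chosen by the autotuner; GQA is resolved inside the kernel, which receives both H_Q and H_KV. A pure-Python sketch of that grid arithmetic, where the BLOCK_M values tried below are illustrative assumptions rather than the module's actual autotune configurations:

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, equivalent to triton.cdiv for positive integers."""
    return (a + b - 1) // b


def launch_grid(batch: int, num_q_heads: int, seq_q: int, block_m: int) -> tuple[int, int]:
    """(num_q_blocks, batch * num_q_heads), mirroring the wrapper's grid()."""
    return (cdiv(seq_q, block_m), batch * num_q_heads)


# Prefill of 1000 tokens, batch 2, 32 query heads, for a few candidate BLOCK_M values
for block_m in (32, 64, 128):
    print(block_m, launch_grid(2, 32, 1000, block_m))  # (32, 64), (16, 64), (8, 64)

# Decode (seq_q == 1): one query block per (batch, query head)
print(launch_grid(2, 32, 1, 64))  # (1, 64)
```

Because the second grid axis depends only on batch and num_q_heads, the grid size is unchanged by the GQA ratio.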

Examples:

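```python
import torch
from turboquant_vllm.triton.flash_attention import triton_flash_attention

q = torch.randn(1, 32, 128, 128, device="cuda", dtype=torch.float16)  # 32 query heads
k = torch.randn(1, 8, 128, 128, device="cuda", dtype=torch.float16)  # 8 KV heads (GQA 4:1)
v = torch.randn_like(k)
out = triton_flash_attention(q, k, v)  # non-causal
out = triton_flash_attention(q, k, v, is_causal=True)  # prefill
```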
See Also

turboquant_vllm.triton.attention_interface: HuggingFace AttentionInterface registration.

Functions

triton_flash_attention

triton_flash_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    sm_scale: float | None = None,
    is_causal: bool = False,
    attention_mask: Tensor | None = None,
) -> Tensor

Compute scaled dot-product attention using Triton Flash Attention.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| q | Tensor | Query tensor [batch, num_q_heads, seq_q, head_dim], fp16 or bf16. num_q_heads must be divisible by num_kv_heads. | required |
| k | Tensor | Key tensor [batch, num_kv_heads, seq_kv, head_dim], same dtype as q. | required |
| v | Tensor | Value tensor [batch, num_kv_heads, seq_kv, head_dim], same dtype as q. | required |
| sm_scale | float \| None | Softmax scale factor. Defaults to 1 / sqrt(head_dim). | None |
| is_causal | bool | Apply causal masking. Only valid when seq_q >= seq_kv (prefill). Forced to False for decode (seq_q == 1). | False |
| attention_mask | Tensor \| None | Optional additive mask [batch or 1, 1 or heads, seq_q, seq_kv]; a size-1 batch or head dimension is broadcast. Values of 0 mean attend; large negative values mean block. Mutually exclusive with is_causal in practice (HF sets is_causal only when attention_mask is None). | None |
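The wrapper broadcasts a mask whose batch or head dimension is 1 by forcing that dimension's stride to 0, so every (batch, head) program reads the same [seq_q, seq_kv] slice. A pure-Python sketch of that stride logic for a contiguous mask; contiguous_strides, mask_strides, and mask_offset are hypothetical helpers, and the offset formula assumes the kernel addresses the mask as z*stride_mz + h*stride_mh + m*stride_mm + n*stride_mn:

```python
def contiguous_strides(shape):
    """Row-major element strides, as Tensor.stride() reports for a contiguous tensor."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def mask_strides(shape):
    """Zero the batch/head strides of size-1 dims, as the wrapper does."""
    s_z, s_h, s_m, s_n = contiguous_strides(shape)
    if shape[0] == 1:
        s_z = 0
    if shape[1] == 1:
        s_h = 0
    return s_z, s_h, s_m, s_n


def mask_offset(strides, z, h, m, n):
    """Element offset of mask[z, h, m, n] under the given strides (assumed addressing)."""
    s_z, s_h, s_m, s_n = strides
    return z * s_z + h * s_h + m * s_m + n * s_n


# HF-style mask shared by every batch entry and head: [1, 1, seq_q, seq_kv]
shape = (1, 1, 4, 6)
print(contiguous_strides(shape))  # (24, 24, 6, 1)
print(mask_strides(shape))        # (0, 0, 6, 1)
# Batch 3, head 7 reads the same element as batch 0, head 0
strides = mask_strides(shape)
print(mask_offset(strides, 3, 7, 2, 5) == mask_offset(strides, 0, 0, 2, 5))  # True
```

Without the zeroed strides, a nonzero batch or head index would step past the end of a [1, 1, seq_q, seq_kv] mask.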

Returns:

| Type | Description |
|---|---|
| Tensor | Attention output [batch, num_q_heads, seq_q, head_dim], with the same dtype and device as q. |

Source code in src/turboquant_vllm/triton/flash_attention.py
def triton_flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    sm_scale: float | None = None,
    is_causal: bool = False,
    attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute scaled dot-product attention using Triton Flash Attention.

    Args:
        q: Query tensor ``[batch, num_q_heads, seq_q, head_dim]``.
        k: Key tensor ``[batch, num_kv_heads, seq_kv, head_dim]``.
        v: Value tensor ``[batch, num_kv_heads, seq_kv, head_dim]``.
        sm_scale: Softmax scale factor. Defaults to ``1 / sqrt(head_dim)``.
        is_causal: Apply causal masking. Only valid when ``seq_q >= seq_kv``
            (prefill). For decode (``seq_q == 1``), forced to ``False``.
        attention_mask: Optional additive mask
            ``[batch, 1|heads, seq_q, seq_kv]``. Values of 0 mean attend,
            large negative values mean block. Mutually exclusive with
            ``is_causal`` in practice (HF sets ``is_causal`` only when
            ``attention_mask is None``).

    Returns:
        Attention output ``[batch, num_q_heads, seq_q, head_dim]``.
    """
    B, H_Q, N_Q, D = q.shape
    _, H_KV, N_KV, _ = k.shape

    assert q.dtype == k.dtype == v.dtype, "Q, K, V must have the same dtype"
    assert q.dtype in (
        torch.float16,
        torch.bfloat16,
    ), f"Triton FA requires fp16 or bf16, got {q.dtype}"
    assert H_Q % H_KV == 0, f"Q heads ({H_Q}) must be divisible by KV heads ({H_KV})"
    assert k.shape[2] == v.shape[2], "K and V must have the same sequence length"
    assert q.shape[3] == k.shape[3] == v.shape[3], "Head dimensions must match"

    # A single (decode) query token is the newest position and may attend to
    # every cached key, so causal masking is unnecessary
    if is_causal and N_Q == 1:
        is_causal = False

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    out = torch.empty_like(q)

    # Handle optional mask
    has_mask = attention_mask is not None
    if has_mask:
        assert attention_mask is not None  # for type narrowing
        mask = attention_mask
        stride_mz, stride_mh, stride_mm, stride_mn = mask.stride()
        # Broadcast: a size-1 batch/head dim can still carry a nonzero stride
        # (e.g. a contiguous [1, 1, seq_q, seq_kv] mask), so force stride=0 and
        # every batch/head program reads the same [seq_q, seq_kv] slice.
        if mask.shape[0] == 1:
            stride_mz = 0
        if mask.shape[1] == 1:
            stride_mh = 0
    else:
        mask = q  # dummy pointer, never dereferenced
        stride_mz = stride_mh = stride_mm = stride_mn = 0

    # Grid: one CTA per (Q-block, batch*q_head)
    def grid(META: dict) -> tuple[int, int]:
        """Compute kernel launch grid from autotuned block size.

        Returns:
            ``(num_q_blocks, batch * num_q_heads)`` grid dimensions.
        """
        return (triton.cdiv(N_Q, META["BLOCK_M"]), B * H_Q)

    _fwd_kernel[grid](
        q,
        k,
        v,
        out,
        mask,
        sm_scale,
        *q.stride(),
        *k.stride(),
        *v.stride(),
        *out.stride(),
        stride_mz,
        stride_mh,
        stride_mm,
        stride_mn,
        H_Q,
        H_KV,
        N_Q,
        N_KV,
        HEAD_DIM=D,
        IS_CAUSAL=is_causal,
        HAS_MASK=has_mask,
    )

    return out
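To cross-check the wrapper's semantics on tiny inputs without a GPU, a naive reference can help. This is a sketch under loudly stated assumptions, not the project's test harness (the docs above compare against SDPA): reference_attention is a hypothetical helper over nested lists shaped [batch][heads][seq][dim]; GQA uses contiguous grouping; causal masking uses top-left alignment (key j is visible to query i iff j <= i), which matches the usual prefill case seq_q == seq_kv; and, like the wrapper, a single query token is never causally masked.

```python
import math


def reference_attention(q, k, v, sm_scale=None, is_causal=False, mask=None):
    """Naive attention over nested lists shaped [batch][heads][seq][dim].

    mask, if given, is additive with shape [batch or 1][1 or heads][seq_q][seq_kv].
    """
    num_q_heads, seq_q, dim = len(q[0]), len(q[0][0]), len(q[0][0][0])
    num_kv_heads, seq_kv = len(k[0]), len(k[0][0])
    assert num_q_heads % num_kv_heads == 0
    group = num_q_heads // num_kv_heads
    if is_causal and seq_q == 1:  # mirror the wrapper: decode is never causal
        is_causal = False
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(dim)
    out = []
    for z in range(len(q)):
        heads = []
        for h in range(num_q_heads):
            kh = h // group  # shared K/V head (contiguous grouping assumed)
            rows = []
            for i in range(seq_q):
                scores = []
                for j in range(seq_kv):
                    s = sm_scale * sum(a * b for a, b in zip(q[z][h][i], k[z][kh][j]))
                    if mask is not None:
                        mz = 0 if len(mask) == 1 else z     # broadcast batch
                        mh = 0 if len(mask[0]) == 1 else h  # broadcast heads
                        s += mask[mz][mh][i][j]
                    if is_causal and j > i:
                        s = -math.inf
                    scores.append(s)
                m = max(scores)
                w = [math.exp(s - m) for s in scores]
                total = sum(w)
                rows.append([sum(w[j] * v[z][kh][j][c] for j in range(seq_kv)) / total
                             for c in range(dim)])
            heads.append(rows)
        out.append(heads)
    return out


# 2 query heads share 1 KV head; seq_q == seq_kv == 2; head_dim == 2
q = [[[[1.0, 0.0], [0.0, 1.0]], [[0.5, 0.5], [1.0, 1.0]]]]
k = [[[[1.0, 0.0], [0.0, 1.0]]]]
v = [[[[1.0, 2.0], [3.0, 4.0]]]]
out = reference_attention(q, k, v, is_causal=True)
print(out[0][0][0])  # [1.0, 2.0]: the first query sees only the first key
```

Comparing such a reference against the kernel only makes sense within fp16/bf16 tolerance, since the kernel accumulates in fp32 from half-precision inputs.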