

turboquant_vllm.triton

Fused Triton kernels for TurboQuant compressed attention.

- Phase 1 (P5): Vanilla Flash Attention kernel with GQA support.
- Phase 2 (P5): Fused TQ4 K decompression inside the FA inner loop.
- Phase 3 (P5): Fused TQ4 K+V decompression with post-rotation.
- Phase 3c.8: Standalone TQ4 cache decompress kernel for the vLLM backend.
- Phase 3c.9: Standalone TQ4 compress kernel for the vLLM backend.

Legacy: Q@K^T-only fused kernel (superseded -- see Key Lesson #7).

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| triton_flash_attention | Tensor | Vanilla FA forward with online softmax. |
| triton_flash_attention_tq4 | Tensor | Fused TQ4 FA with compressed K tiles. |
| triton_flash_attention_tq4_kv | Tensor | Fused TQ4 FA with compressed K+V tiles. |
| triton_fa_forward | tuple[Tensor, None] | HF AttentionInterface-compatible wrapper. |
| register_triton_fa | None | Register the triton_fa backend globally. |
| install_triton_fa | None | Register and activate vanilla FA on a model. |
| install_fused_tq4_kv | None | Activate fused TQ4 K+V with cache side-channel. |
| uninstall_fused_tq4_kv | None | Remove fused attention and restore SDPA. |
| tq4_compress | None | Fused TQ4 compress (norm+rotate+quantize+pack). |
| tq4_decompress | None | Fused TQ4 decompress (unpack+gather+scale). |
| fused_qk_scores | Tensor | Legacy Q@K^T-only kernel (kept for reference). |

Examples:

Direct kernel usage:

from turboquant_vllm.triton import triton_flash_attention

out = triton_flash_attention(q, k, v)

HuggingFace integration:

from turboquant_vllm.triton import install_triton_fa

install_triton_fa(model)
output = model.generate(...)

See Also

turboquant_vllm.kv_cache: CompressedDynamicCache storage layer.

Functions

install_fused_tq4_kv

install_fused_tq4_kv(model: Module, cache: CompressedDynamicCache) -> None

Activate fused TQ4 K+V attention on model with cache side-channel.

Registers the triton_fa_tq4_kv backend, stashes cache on each attention layer as module._tq4_cache, sets the model's _attn_implementation, and enables fused_mode on the cache to skip wasted decompression (P5b optimization).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Module | HuggingFace model with attention layers that have layer_idx. | required |
| cache | CompressedDynamicCache | CompressedDynamicCache instance that stores compressed K/V. | required |

Raises:

| Type | Description |
| --- | --- |
| AttributeError | If model has no config attribute. |

Source code in src/turboquant_vllm/triton/attention_interface.py
def install_fused_tq4_kv(model: torch.nn.Module, cache: CompressedDynamicCache) -> None:
    """Activate fused TQ4 K+V attention on *model* with cache side-channel.

    Registers the ``triton_fa_tq4_kv`` backend, stashes *cache* on each
    attention layer as ``module._tq4_cache``, sets the model's
    ``_attn_implementation``, and enables ``fused_mode`` on the cache
    to skip wasted decompression (P5b optimization).

    Args:
        model: HuggingFace model with attention layers that have ``layer_idx``.
        cache: CompressedDynamicCache instance that stores compressed K/V.

    Raises:
        AttributeError: If *model* has no ``config`` attribute.
    """
    ALL_ATTENTION_FUNCTIONS.register("triton_fa_tq4_kv", triton_fa_tq4_kv_forward)

    config = getattr(model, "config", None)
    if config is None:
        msg = "Model has no config attribute"
        raise AttributeError(msg)
    config._attn_implementation = "triton_fa_tq4_kv"

    # Enable fused mode: skip decompression in cache.update()
    cache.fused_mode = True

    # Stash cache reference on each attention layer
    for module in model.modules():
        if hasattr(module, "layer_idx"):
            object.__setattr__(module, "_tq4_cache", cache)
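
The side-channel wiring above can be illustrated without torch. This is a minimal sketch of the pattern only; `ToyCache`, `ToyLayer`, and `install` are hypothetical stand-ins for illustration, not part of the package:

```python
# Sketch of the cache side-channel pattern used by install_fused_tq4_kv:
# flip fused_mode on the cache, then stash a cache reference on every
# layer that exposes a layer_idx attribute.

class ToyCache:
    """Stand-in for CompressedDynamicCache (hypothetical)."""
    def __init__(self):
        self.fused_mode = False

class ToyLayer:
    """Stand-in for an attention layer with a layer_idx (hypothetical)."""
    def __init__(self, layer_idx):
        self.layer_idx = layer_idx

def install(layers, cache):
    cache.fused_mode = True           # skip wasted decompression in cache.update()
    for layer in layers:
        if hasattr(layer, "layer_idx"):
            layer._tq4_cache = cache  # side-channel the kernel reads compressed K/V from

layers = [ToyLayer(i) for i in range(2)]
cache = ToyCache()
install(layers, cache)
assert cache.fused_mode and all(l._tq4_cache is cache for l in layers)
```

The real function does the same stash via `object.__setattr__` to sidestep `nn.Module.__setattr__` bookkeeping.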

install_triton_fa

install_triton_fa(model: Module) -> None

Register the backend and activate it on model.

Changes model.config._attn_implementation to "triton_fa". The model resolves the attention function at forward time, so this takes effect on the next forward call.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Module | A HuggingFace model with a config attribute. | required |

Raises:

| Type | Description |
| --- | --- |
| AttributeError | If model has no config attribute. |

Source code in src/turboquant_vllm/triton/attention_interface.py
def install_triton_fa(model: torch.nn.Module) -> None:
    """Register the backend and activate it on *model*.

    Changes ``model.config._attn_implementation`` to ``"triton_fa"``.
    The model resolves the attention function at forward time, so this
    takes effect on the next forward call.

    Args:
        model: A HuggingFace model with a ``config`` attribute.

    Raises:
        AttributeError: If *model* has no ``config`` attribute.
    """
    register_triton_fa()
    config = getattr(model, "config", None)
    if config is None:
        msg = "Model has no config attribute"
        raise AttributeError(msg)
    config._attn_implementation = "triton_fa"

register_triton_fa

register_triton_fa() -> None

Register triton_fa as a global attention backend in HuggingFace.

Safe to call multiple times -- overwrites the previous registration.

Source code in src/turboquant_vllm/triton/attention_interface.py
def register_triton_fa() -> None:
    """Register ``triton_fa`` as a global attention backend in HuggingFace.

    Safe to call multiple times -- overwrites the previous registration.
    """
    ALL_ATTENTION_FUNCTIONS.register("triton_fa", triton_fa_forward)

triton_fa_forward

triton_fa_forward(
    module: Module,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attention_mask: Optional[Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    **kwargs: object,
) -> tuple[Tensor, None]

HF-compatible attention forward using Triton Flash Attention.

Signature matches transformers.integrations.sdpa_attention.sdpa_attention_forward. Handles GQA natively (no KV repeat expansion needed).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| module | Module | The attention layer module. Used to read is_causal attribute. | required |
| query | Tensor | [batch, num_q_heads, seq_q, head_dim]. | required |
| key | Tensor | [batch, num_kv_heads, seq_kv, head_dim]. | required |
| value | Tensor | [batch, num_kv_heads, seq_kv, head_dim]. | required |
| attention_mask | Optional[Tensor] | Optional additive mask [batch, 1\|heads, seq_q, seq_kv]. | required |
| dropout | float | Dropout rate (must be 0 -- inference only). | 0.0 |
| scaling | Optional[float] | Softmax scale. Defaults to 1 / sqrt(head_dim). | None |

Other Parameters:

| Name | Type | Description |
| --- | --- | --- |
| is_causal | bool \| None | Override causal mode. If None, inferred from query.shape[2] and module.is_causal. |
| **kwargs | object | Additional model-specific arguments (ignored). |

Returns:

| Type | Description |
| --- | --- |
| tuple[Tensor, None] | (output, None) where output is [batch, seq_q, num_q_heads, head_dim] (transposed to match HF convention). |

Source code in src/turboquant_vllm/triton/attention_interface.py
def triton_fa_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    **kwargs: object,
) -> tuple[torch.Tensor, None]:
    """HF-compatible attention forward using Triton Flash Attention.

    Signature matches ``transformers.integrations.sdpa_attention.sdpa_attention_forward``.
    Handles GQA natively (no KV repeat expansion needed).

    Args:
        module: The attention layer module. Used to read ``is_causal`` attribute.
        query: ``[batch, num_q_heads, seq_q, head_dim]``.
        key: ``[batch, num_kv_heads, seq_kv, head_dim]``.
        value: ``[batch, num_kv_heads, seq_kv, head_dim]``.
        attention_mask: Optional additive mask ``[batch, 1|heads, seq_q, seq_kv]``.
        dropout: Dropout rate (must be 0 -- inference only).
        scaling: Softmax scale. Defaults to ``1 / sqrt(head_dim)``.

    Other Parameters:
        is_causal (bool | None): Override causal mode. If ``None``, inferred
            from ``query.shape[2]`` and ``module.is_causal``.
        **kwargs: Additional model-specific arguments (ignored).

    Returns:
        ``(output, None)`` where output is ``[batch, seq_q, num_q_heads, head_dim]``
        (transposed to match HF convention).
    """
    # Determine causal mode (same logic as HF SDPA backend)
    is_causal_raw = kwargs.pop("is_causal", None)
    is_causal_flag: bool
    if is_causal_raw is None:
        is_causal_flag = bool(
            query.shape[2] > 1
            and attention_mask is None
            and getattr(module, "is_causal", True)
        )
    else:
        is_causal_flag = bool(is_causal_raw)

    out = triton_flash_attention(
        query,
        key,
        value,
        sm_scale=scaling,
        is_causal=is_causal_flag,
        attention_mask=attention_mask if not is_causal_flag else None,
    )

    # Transpose to [batch, seq, heads, head_dim] per HF convention
    out = out.transpose(1, 2).contiguous()
    return out, None
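
The causal-mode decision above is small enough to isolate as a pure function: causal masking applies only during prefill (more than one query token) with no explicit mask and a causal layer. `infer_is_causal` is a hypothetical helper restating that logic, not part of the package:

```python
# Sketch of the causal-flag inference used by triton_fa_forward.
# A single-token decode step never needs causal masking, and an explicit
# attention_mask takes precedence over implicit causality (HF SDPA convention).
def infer_is_causal(seq_q, has_mask, module_is_causal=True, override=None):
    if override is not None:
        return bool(override)
    return bool(seq_q > 1 and not has_mask and module_is_causal)

assert infer_is_causal(seq_q=128, has_mask=False) is True    # prefill
assert infer_is_causal(seq_q=1, has_mask=False) is False     # decode
assert infer_is_causal(seq_q=128, has_mask=True) is False    # explicit mask wins
assert infer_is_causal(seq_q=1, has_mask=True, override=True) is True
```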

triton_fa_tq4_kv_forward

triton_fa_tq4_kv_forward(
    module: Module,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attention_mask: Optional[Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    **kwargs: object,
) -> tuple[Tensor, None]

Fused TQ4 K+V attention via cache side-channel.

Reads compressed K/V from the CompressedDynamicCache stashed on module._tq4_cache (ignoring the decompressed key/value args). Falls back to vanilla Triton FA if no cache reference is found.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| module | Module | Attention layer with layer_idx and _tq4_cache attrs. | required |
| query | Tensor | [batch, H_Q, seq_q, head_dim] (RoPE already applied). | required |
| key | Tensor | [batch, H_KV, seq_kv, head_dim]; ignored when fused. | required |
| value | Tensor | [batch, H_KV, seq_kv, head_dim]; ignored when fused. | required |
| attention_mask | Optional[Tensor] | Optional additive mask. | required |
| dropout | float | Must be 0 (inference only). | 0.0 |
| scaling | Optional[float] | Softmax scale. | None |

Other Parameters:

| Name | Type | Description |
| --- | --- | --- |
| is_causal | bool \| None | Override causal mode. |
| **kwargs | object | Additional model-specific arguments (ignored). |

Returns:

| Type | Description |
| --- | --- |
| tuple[Tensor, None] | (output, None) with output [batch, seq_q, H_Q, head_dim]. |

Source code in src/turboquant_vllm/triton/attention_interface.py
def triton_fa_tq4_kv_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    **kwargs: object,
) -> tuple[torch.Tensor, None]:
    """Fused TQ4 K+V attention via cache side-channel.

    Reads compressed K/V from the ``CompressedDynamicCache`` stashed on
    ``module._tq4_cache`` (ignoring the decompressed *key*/*value* args).
    Falls back to vanilla Triton FA if no cache reference is found.

    Args:
        module: Attention layer with ``layer_idx`` and ``_tq4_cache`` attrs.
        query: ``[batch, H_Q, seq_q, head_dim]`` (RoPE already applied).
        key: ``[batch, H_KV, seq_kv, head_dim]`` — ignored when fused.
        value: ``[batch, H_KV, seq_kv, head_dim]`` — ignored when fused.
        attention_mask: Optional additive mask.
        dropout: Must be 0 (inference only).
        scaling: Softmax scale.

    Other Parameters:
        is_causal (bool | None): Override causal mode.
        **kwargs: Additional model-specific arguments (ignored).

    Returns:
        ``(output, None)`` with output ``[batch, seq_q, H_Q, head_dim]``.
    """
    cache: CompressedDynamicCache | None = getattr(module, "_tq4_cache", None)
    layer_idx: int | None = getattr(module, "layer_idx", None)

    # Fallback: no cache stash → use vanilla Triton FA with decompressed K/V
    if cache is None or layer_idx is None:
        return triton_fa_forward(
            module, query, key, value, attention_mask, dropout, scaling, **kwargs
        )

    # Read compressed K/V from cache side-channel
    k_packed, k_norms, v_packed, v_norms = cache.get_compressed(layer_idx)
    rotation = cache.rotation.to(device=query.device)
    centroids = cache.centroids.to(device=query.device)

    # Determine causal mode
    is_causal_raw = kwargs.pop("is_causal", None)
    is_causal_flag: bool
    if is_causal_raw is None:
        is_causal_flag = bool(
            query.shape[2] > 1
            and attention_mask is None
            and getattr(module, "is_causal", True)
        )
    else:
        is_causal_flag = bool(is_causal_raw)

    out = triton_flash_attention_tq4_kv(
        query,
        k_packed,
        k_norms,
        v_packed,
        v_norms,
        centroids,
        rotation,
        sm_scale=scaling,
        is_causal=is_causal_flag,
    )

    out = out.transpose(1, 2).contiguous()
    return out, None

uninstall_fused_tq4_kv

uninstall_fused_tq4_kv(model: Module) -> None

Remove fused TQ4 attention and restore SDPA.

Removes _tq4_cache from attention layers, disables fused_mode on the cache, and resets _attn_implementation to "sdpa".

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Module | Model previously configured with install_fused_tq4_kv. | required |
Source code in src/turboquant_vllm/triton/attention_interface.py
def uninstall_fused_tq4_kv(model: torch.nn.Module) -> None:
    """Remove fused TQ4 attention and restore SDPA.

    Removes ``_tq4_cache`` from attention layers, disables ``fused_mode``
    on the cache, and resets ``_attn_implementation`` to ``"sdpa"``.

    Args:
        model: Model previously configured with ``install_fused_tq4_kv``.
    """
    config = getattr(model, "config", None)
    if config is not None:
        config._attn_implementation = "sdpa"

    for module in model.modules():
        if hasattr(module, "_tq4_cache"):
            cache = getattr(module, "_tq4_cache", None)
            if cache is not None and hasattr(cache, "fused_mode"):
                cache.fused_mode = False  # type: ignore[union-attr]
            delattr(module, "_tq4_cache")

triton_flash_attention

triton_flash_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    sm_scale: float | None = None,
    is_causal: bool = False,
    attention_mask: Tensor | None = None,
) -> Tensor

Compute scaled dot-product attention using Triton Flash Attention.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q | Tensor | Query tensor [batch, num_q_heads, seq_q, head_dim]. | required |
| k | Tensor | Key tensor [batch, num_kv_heads, seq_kv, head_dim]. | required |
| v | Tensor | Value tensor [batch, num_kv_heads, seq_kv, head_dim]. | required |
| sm_scale | float \| None | Softmax scale factor. Defaults to 1 / sqrt(head_dim). | None |
| is_causal | bool | Apply causal masking. Only valid when seq_q >= seq_kv (prefill). For decode (seq_q == 1), forced to False. | False |
| attention_mask | Tensor \| None | Optional additive mask [batch, 1\|heads, seq_q, seq_kv]. Values of 0 mean attend, large negative values mean block. Mutually exclusive with is_causal in practice (HF sets is_causal only when attention_mask is None). | None |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Attention output [batch, num_q_heads, seq_q, head_dim]. |

Source code in src/turboquant_vllm/triton/flash_attention.py
def triton_flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    sm_scale: float | None = None,
    is_causal: bool = False,
    attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute scaled dot-product attention using Triton Flash Attention.

    Args:
        q: Query tensor ``[batch, num_q_heads, seq_q, head_dim]``.
        k: Key tensor ``[batch, num_kv_heads, seq_kv, head_dim]``.
        v: Value tensor ``[batch, num_kv_heads, seq_kv, head_dim]``.
        sm_scale: Softmax scale factor. Defaults to ``1 / sqrt(head_dim)``.
        is_causal: Apply causal masking. Only valid when ``seq_q >= seq_kv``
            (prefill). For decode (``seq_q == 1``), forced to ``False``.
        attention_mask: Optional additive mask
            ``[batch, 1|heads, seq_q, seq_kv]``. Values of 0 mean attend,
            large negative values mean block. Mutually exclusive with
            ``is_causal`` in practice (HF sets ``is_causal`` only when
            ``attention_mask is None``).

    Returns:
        Attention output ``[batch, num_q_heads, seq_q, head_dim]``.
    """
    B, H_Q, N_Q, D = q.shape
    _, H_KV, N_KV, _ = k.shape

    assert q.dtype == k.dtype == v.dtype, "Q, K, V must have the same dtype"
    assert q.dtype in (
        torch.float16,
        torch.bfloat16,
    ), f"Triton FA requires fp16 or bf16, got {q.dtype}"
    assert H_Q % H_KV == 0, f"Q heads ({H_Q}) must be divisible by KV heads ({H_KV})"
    assert k.shape[2] == v.shape[2], "K and V must have the same sequence length"
    assert q.shape[3] == k.shape[3] == v.shape[3], "Head dimensions must match"

    # Single query token never needs causal masking
    if is_causal and N_Q == 1:
        is_causal = False

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    out = torch.empty_like(q)

    # Handle optional mask
    has_mask = attention_mask is not None
    if has_mask:
        assert attention_mask is not None  # for type narrowing
        mask = attention_mask
        stride_mz, stride_mh, stride_mm, stride_mn = mask.stride()
        # Handle broadcast: size-1 dims have nonzero stride in PyTorch
        # but must be treated as stride=0 for correct broadcast behavior.
        if mask.shape[0] == 1:
            stride_mz = 0
        if mask.shape[1] == 1:
            stride_mh = 0
    else:
        mask = q  # dummy pointer, never dereferenced
        stride_mz = stride_mh = stride_mm = stride_mn = 0

    # Grid: one CTA per (Q-block, batch*q_head)
    def grid(META: dict) -> tuple[int, int]:
        """Compute kernel launch grid from autotuned block size.

        Returns:
            ``(num_q_blocks, batch * num_q_heads)`` grid dimensions.
        """
        return (triton.cdiv(N_Q, META["BLOCK_M"]), B * H_Q)

    _fwd_kernel[grid](
        q,
        k,
        v,
        out,
        mask,
        sm_scale,
        *q.stride(),
        *k.stride(),
        *v.stride(),
        *out.stride(),
        stride_mz,
        stride_mh,
        stride_mm,
        stride_mn,
        H_Q,
        H_KV,
        N_Q,
        N_KV,
        HEAD_DIM=D,
        IS_CAUSAL=is_causal,
        HAS_MASK=has_mask,
    )

    return out
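
The kernel's core numerical trick is the online softmax: K/V blocks are streamed, and the running accumulator is rescaled whenever a new row maximum appears, so the full score row never needs to be materialized. Below is a pure-Python sketch of the single-row recurrence only (not the kernel itself); `online_softmax_attention` is a hypothetical name for illustration:

```python
import math

def online_softmax_attention(scores, values):
    """Streaming softmax-weighted sum over (score, value) pairs.

    Numerically equivalent to softmax(scores) @ values for a single
    query row, but processes one entry at a time.
    """
    m = float("-inf")  # running max of seen scores
    l = 0.0            # running sum of exp(score - m)
    acc = 0.0          # running weighted sum of values
    for s, v in zip(scores, values):
        m_new = max(m, s)
        # Rescale previous state when the max changes
        alpha = math.exp(m - m_new) if m != float("-inf") else 0.0
        p = math.exp(s - m_new)
        l = l * alpha + p
        acc = acc * alpha + p * v
        m = m_new
    return acc / l

scores = [0.5, 2.0, -1.0, 3.0]
values = [1.0, 2.0, 3.0, 4.0]
ref = sum(math.exp(s) * v for s, v in zip(scores, values)) / sum(
    math.exp(s) for s in scores
)
assert abs(online_softmax_attention(scores, values) - ref) < 1e-9
```

The Triton kernel applies the same recurrence per Q block with vector-valued `acc`, which is what makes the causal and masked variants drop-in changes to the inner loop.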

triton_flash_attention_tq4

triton_flash_attention_tq4(
    q: Tensor,
    k_packed: Tensor,
    k_norms: Tensor,
    centroids: Tensor,
    rotation: Tensor,
    v: Tensor,
    sm_scale: float | None = None,
    is_causal: bool = False,
) -> Tensor

Fused TQ4 Flash Attention with compressed K and standard V.

Pre-rotates Q by rotation^T, then launches the fused kernel that decompresses nibble-packed K indices inline via centroid gather.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q | Tensor | Query [batch, H_Q, seq_q, head_dim], fp16/bf16. | required |
| k_packed | Tensor | Nibble-packed key indices [batch, H_KV, seq_kv, head_dim//2], uint8. | required |
| k_norms | Tensor | Key norms [batch, H_KV, seq_kv] or [..., 1], fp32. | required |
| centroids | Tensor | Lloyd-Max codebook [16], fp32 (for 4-bit). | required |
| rotation | Tensor | Orthogonal rotation matrix [head_dim, head_dim], fp32. | required |
| v | Tensor | Values [batch, H_KV, seq_kv, head_dim], fp16/bf16. | required |
| sm_scale | float \| None | Softmax scale. Defaults to 1 / sqrt(head_dim). | None |
| is_causal | bool | Apply causal masking. | False |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Attention output [batch, H_Q, seq_q, head_dim]. |

Source code in src/turboquant_vllm/triton/flash_attention_tq4.py
def triton_flash_attention_tq4(
    q: torch.Tensor,
    k_packed: torch.Tensor,
    k_norms: torch.Tensor,
    centroids: torch.Tensor,
    rotation: torch.Tensor,
    v: torch.Tensor,
    sm_scale: float | None = None,
    is_causal: bool = False,
) -> torch.Tensor:
    """Fused TQ4 Flash Attention with compressed K and standard V.

    Pre-rotates Q by ``rotation^T``, then launches the fused kernel that
    decompresses nibble-packed K indices inline via centroid gather.

    Args:
        q: Query ``[batch, H_Q, seq_q, head_dim]`` fp16/bf16.
        k_packed: Nibble-packed key indices ``[batch, H_KV, seq_kv, head_dim//2]`` uint8.
        k_norms: Key norms ``[batch, H_KV, seq_kv]`` or ``[..., 1]`` fp32.
        centroids: Lloyd-Max codebook ``[16]`` fp32 (for 4-bit).
        rotation: Orthogonal rotation matrix ``[head_dim, head_dim]`` fp32.
        v: Values ``[batch, H_KV, seq_kv, head_dim]`` fp16/bf16.
        sm_scale: Softmax scale. Defaults to ``1 / sqrt(head_dim)``.
        is_causal: Apply causal masking.

    Returns:
        Attention output ``[batch, H_Q, seq_q, head_dim]``.
    """
    B, H_Q, N_Q, D = q.shape
    _, H_KV, N_KV, HALF_D = k_packed.shape

    assert HALF_D == D // 2, f"Packed dim {HALF_D} != head_dim//2 ({D // 2})"
    assert H_Q % H_KV == 0, f"Q heads ({H_Q}) must be divisible by KV heads ({H_KV})"
    assert k_packed.dtype == torch.uint8, "k_packed must be uint8"
    assert k_norms.dtype == torch.float32, "k_norms must be float32"
    assert centroids.dtype == torch.float32, "centroids must be float32"

    # Squeeze trailing 1 from norms if present
    if k_norms.dim() == 4 and k_norms.shape[-1] == 1:
        k_norms = k_norms.squeeze(-1)

    if is_causal and N_Q == 1:
        is_causal = False

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # Pre-rotate Q: q_rot = q @ Pi^T
    q_rot = torch.matmul(q.float(), rotation.T).to(q.dtype)

    out = torch.empty_like(q)

    def grid(META: dict) -> tuple[int, int]:
        """Compute launch grid from autotuned block size.

        Returns:
            ``(num_q_blocks, batch * H_Q)`` grid dimensions.
        """
        return (triton.cdiv(N_Q, META["BLOCK_M"]), B * H_Q)

    _fwd_tq4_kernel[grid](
        q_rot,
        k_packed,
        k_norms,
        centroids,
        v,
        out,
        sm_scale,
        *q_rot.stride(),
        *k_packed.stride(),
        *k_norms.stride(),
        *v.stride(),
        *out.stride(),
        H_Q,
        H_KV,
        N_Q,
        N_KV,
        HEAD_DIM=D,
        IS_CAUSAL=is_causal,
    )

    return out
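
The inline decompression amounts to three cheap steps per byte: unpack two 4-bit indices with shift/mask, gather the corresponding centroid values, and rescale by the stored per-vector norm. A pure-Python sketch of that path follows; the low-nibble-first packing order and the helper names are assumptions for illustration, not the kernel's actual layout:

```python
def pack_nibbles(indices):
    """Pack pairs of 4-bit indices into bytes (low nibble = even position)."""
    assert len(indices) % 2 == 0 and all(0 <= i < 16 for i in indices)
    return bytes(
        indices[i] | (indices[i + 1] << 4) for i in range(0, len(indices), 2)
    )

def decompress(packed, centroids, norm):
    """Unpack nibbles, gather centroid values, scale by the vector norm."""
    out = []
    for byte in packed:
        out.append(centroids[byte & 0x0F] * norm)  # low nibble
        out.append(centroids[byte >> 4] * norm)    # high nibble
    return out

centroids = [(-7.5 + i) / 8.0 for i in range(16)]  # toy 16-level codebook
idx = [0, 15, 7, 8]
vec = decompress(pack_nibbles(idx), centroids, norm=2.0)
assert vec == [centroids[i] * 2.0 for i in idx]
```

In the kernel this gather happens on whole K tiles inside the FA inner loop, so the decompressed keys never round-trip through global memory.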

triton_flash_attention_tq4_kv

triton_flash_attention_tq4_kv(
    q: Tensor,
    k_packed: Tensor,
    k_norms: Tensor,
    v_packed: Tensor,
    v_norms: Tensor,
    centroids: Tensor,
    rotation: Tensor,
    sm_scale: float | None = None,
    is_causal: bool = False,
) -> Tensor

Fused TQ4 Flash Attention with both K and V compressed.

Pre-rotates Q by rotation^T, launches the kernel that decompresses both K and V inline, then post-rotates the output by rotation to return to the original coordinate space.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q | Tensor | Query [batch, H_Q, seq_q, head_dim], fp16/bf16. | required |
| k_packed | Tensor | Nibble-packed key indices [batch, H_KV, seq_kv, D//2], uint8. | required |
| k_norms | Tensor | Key norms [batch, H_KV, seq_kv] or [..., 1], fp32. | required |
| v_packed | Tensor | Nibble-packed value indices [batch, H_KV, seq_kv, D//2], uint8. | required |
| v_norms | Tensor | Value norms [batch, H_KV, seq_kv] or [..., 1], fp32. | required |
| centroids | Tensor | Shared Lloyd-Max codebook [16], fp32. | required |
| rotation | Tensor | Shared orthogonal rotation [head_dim, head_dim], fp32. | required |
| sm_scale | float \| None | Softmax scale. Defaults to 1 / sqrt(head_dim). | None |
| is_causal | bool | Apply causal masking. | False |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Attention output [batch, H_Q, seq_q, head_dim] in original space. |

Source code in src/turboquant_vllm/triton/flash_attention_tq4_kv.py
def triton_flash_attention_tq4_kv(
    q: torch.Tensor,
    k_packed: torch.Tensor,
    k_norms: torch.Tensor,
    v_packed: torch.Tensor,
    v_norms: torch.Tensor,
    centroids: torch.Tensor,
    rotation: torch.Tensor,
    sm_scale: float | None = None,
    is_causal: bool = False,
) -> torch.Tensor:
    """Fused TQ4 Flash Attention with both K and V compressed.

    Pre-rotates Q by ``rotation^T``, launches the kernel that decompresses
    both K and V inline, then post-rotates the output by ``rotation`` to
    return to the original coordinate space.

    Args:
        q: Query ``[batch, H_Q, seq_q, head_dim]`` fp16/bf16.
        k_packed: Nibble-packed key indices ``[batch, H_KV, seq_kv, D//2]`` uint8.
        k_norms: Key norms ``[batch, H_KV, seq_kv]`` or ``[..., 1]`` fp32.
        v_packed: Nibble-packed value indices ``[batch, H_KV, seq_kv, D//2]`` uint8.
        v_norms: Value norms ``[batch, H_KV, seq_kv]`` or ``[..., 1]`` fp32.
        centroids: Shared Lloyd-Max codebook ``[16]`` fp32.
        rotation: Shared orthogonal rotation ``[head_dim, head_dim]`` fp32.
        sm_scale: Softmax scale. Defaults to ``1 / sqrt(head_dim)``.
        is_causal: Apply causal masking.

    Returns:
        Attention output ``[batch, H_Q, seq_q, head_dim]`` in original space.
    """
    B, H_Q, N_Q, D = q.shape
    _, H_KV, N_KV, HALF_D = k_packed.shape

    assert HALF_D == D // 2
    assert H_Q % H_KV == 0
    assert k_packed.dtype == torch.uint8
    assert v_packed.dtype == torch.uint8
    assert k_norms.dtype == torch.float32
    assert v_norms.dtype == torch.float32

    # Squeeze trailing 1 from norms
    if k_norms.dim() == 4 and k_norms.shape[-1] == 1:
        k_norms = k_norms.squeeze(-1)
    if v_norms.dim() == 4 and v_norms.shape[-1] == 1:
        v_norms = v_norms.squeeze(-1)

    if is_causal and N_Q == 1:
        is_causal = False

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # Pre-rotate Q
    q_rot = torch.matmul(q.float(), rotation.T).to(q.dtype)

    out_rot = torch.empty_like(q)

    def grid(META: dict) -> tuple[int, int]:
        """Compute launch grid.

        Returns:
            ``(num_q_blocks, batch * H_Q)`` grid dimensions.
        """
        return (triton.cdiv(N_Q, META["BLOCK_M"]), B * H_Q)

    _fwd_tq4_kv_kernel[grid](
        q_rot,
        k_packed,
        k_norms,
        v_packed,
        v_norms,
        centroids,
        out_rot,
        sm_scale,
        *q_rot.stride(),
        *k_packed.stride(),
        *k_norms.stride(),
        *v_packed.stride(),
        *v_norms.stride(),
        *out_rot.stride(),
        H_Q,
        H_KV,
        N_Q,
        N_KV,
        HEAD_DIM=D,
        IS_CAUSAL=is_causal,
    )

    # Post-rotate: convert from rotated space back to original space
    return torch.matmul(out_rot.float(), rotation).to(q.dtype)
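
The pre/post rotation scheme is exact because the rotation Pi is orthogonal: rotating Q and K by the same matrix leaves the score matrix unchanged, so only the output (computed against rotated V) needs one final multiply by Pi to return to the original basis:

```latex
(Q\Pi^{\top})(K\Pi^{\top})^{\top} = Q\,\Pi^{\top}\Pi\,K^{\top} = QK^{\top}
\quad\text{since } \Pi^{\top}\Pi = I,
\qquad
O = \mathrm{softmax}\!\left(QK^{\top}/\sqrt{d}\right)(V\Pi^{\top})\,\Pi
  = \mathrm{softmax}\!\left(QK^{\top}/\sqrt{d}\right)V .
```

This is why the wrapper can quantize K and V in the rotated space while callers only ever see outputs in the original coordinate space.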

fused_qk_scores

fused_qk_scores(
    q_rotated: Tensor,
    packed_indices: Tensor,
    norms: Tensor,
    centroids: Tensor,
    scale: float,
    *,
    n_q_heads: int,
    n_kv_heads: int,
) -> Tensor

Compute attention scores from pre-rotated queries and nibble-packed keys.

The query must be pre-rotated by the TurboQuant rotation matrix (q_rot = q @ Pi_T) so the kernel avoids the expensive 128x128 rotation matmul in its inner loop.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q_rotated | Tensor | Pre-rotated query, shape (batch, n_q_heads, q_len, head_dim). | required |
| packed_indices | Tensor | Nibble-packed 4-bit key indices, shape (batch, n_kv_heads, kv_len, head_dim // 2), dtype uint8. | required |
| norms | Tensor | Key vector norms, shape (batch, n_kv_heads, kv_len), dtype float32. | required |
| centroids | Tensor | Lloyd-Max centroid values, shape (n_levels,), dtype float32. | required |
| scale | float | Attention scale factor (typically 1 / sqrt(head_dim)). | required |
| n_q_heads | int | Number of query attention heads. | required |
| n_kv_heads | int | Number of key-value attention heads. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Attention scores, shape (batch, n_q_heads, q_len, kv_len). |

Source code in src/turboquant_vllm/triton/fused_qk_attention.py
def fused_qk_scores(
    q_rotated: torch.Tensor,
    packed_indices: torch.Tensor,
    norms: torch.Tensor,
    centroids: torch.Tensor,
    scale: float,
    *,
    n_q_heads: int,
    n_kv_heads: int,
) -> torch.Tensor:
    """Compute attention scores from pre-rotated queries and nibble-packed keys.

    The query must be pre-rotated by the TurboQuant rotation matrix
    (``q_rot = q @ Pi_T``) so the kernel avoids the expensive 128x128
    rotation matmul in its inner loop.

    Args:
        q_rotated: Pre-rotated query, shape
            ``(batch, n_q_heads, q_len, head_dim)``.
        packed_indices: Nibble-packed 4-bit key indices, shape
            ``(batch, n_kv_heads, kv_len, head_dim // 2)``, dtype uint8.
        norms: Key vector norms, shape
            ``(batch, n_kv_heads, kv_len)``, dtype float32.
        centroids: Lloyd-Max centroid values, shape ``(n_levels,)``,
            dtype float32.
        scale: Attention scale factor (typically ``1 / sqrt(head_dim)``).
        n_q_heads: Number of query attention heads.
        n_kv_heads: Number of key-value attention heads.

    Returns:
        Attention scores, shape ``(batch, n_q_heads, q_len, kv_len)``.
    """
    batch, _, q_len, head_dim = q_rotated.shape
    _, _, kv_len, _ = packed_indices.shape

    ki_flat = packed_indices.reshape(
        batch * n_kv_heads, kv_len, head_dim // 2
    ).contiguous()
    kn_flat = norms.reshape(batch * n_kv_heads, kv_len).contiguous()
    centroids = centroids.contiguous().float()

    # For q_len > 1 (prefill), process each query position separately
    # to keep the GQA head mapping correct. The kernel maps
    # q_head → kv_head via gqa_ratio = n_q_heads // n_kv_heads.
    # Flattening q_len into heads would break this mapping.
    results = []
    for q_pos in range(q_len):
        q_slice = (
            q_rotated[:, :, q_pos : q_pos + 1, :]
            .reshape(batch * n_q_heads, head_dim)
            .contiguous()
        )

        out = torch.empty(
            batch * n_q_heads,
            kv_len,
            device=q_rotated.device,
            dtype=torch.float32,
        )

        grid = (batch * n_q_heads, triton.cdiv(kv_len, 64))

        _fused_qk_nibble_kernel[grid](
            q_slice,
            ki_flat,
            kn_flat,
            centroids,
            out,
            kv_len,
            head_dim,
            n_q_heads,
            n_kv_heads,
            scale,
            q_slice.stride(0),
            q_slice.stride(1),
            ki_flat.stride(0),
            ki_flat.stride(1),
            ki_flat.stride(2),
            kn_flat.stride(0),
            kn_flat.stride(1),
            out.stride(0),
            out.stride(1),
        )

        results.append(out.reshape(batch, n_q_heads, 1, kv_len))

    return torch.cat(results, dim=2)
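
The per-position loop in the legacy kernel exists to keep the grouped-query (GQA) head mapping intact: each query head must read the KV head at index `q_head // gqa_ratio`. A minimal sketch of that mapping (`gqa_kv_head` is a hypothetical helper for illustration):

```python
def gqa_kv_head(q_head, n_q_heads, n_kv_heads):
    """Map a query head index to the KV head it shares under GQA."""
    assert n_q_heads % n_kv_heads == 0, "Q heads must be divisible by KV heads"
    gqa_ratio = n_q_heads // n_kv_heads
    return q_head // gqa_ratio

# 8 query heads sharing 2 KV heads: heads 0-3 -> KV 0, heads 4-7 -> KV 1
assert [gqa_kv_head(h, 8, 2) for h in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]
```

Flattening q_len into the head axis would shift every head index and break this division, which is the pitfall the comment in `fused_qk_scores` guards against; the Phase 1+ Flash Attention kernels handle the mapping inside the kernel instead.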