triton

turboquant_vllm.triton

Fused Triton kernels for TurboQuant compressed attention.

Phase 1 (P5): Vanilla Flash Attention kernel with GQA support.
Phase 2 (P5): Fused TQ4 K decompression inside the FA inner loop.
Phase 3 (P5): Fused TQ4 K+V decompression with post-rotation.
Phase 3c.8: Standalone TQ4 cache decompress kernel for vLLM backend.
Phase 3c.9: Standalone TQ4 compress kernel for vLLM backend.
Phase 3a (D9): Fused paged TQ4 decode attention -- decompresses from page table.
Phase 3b (D9): Fused paged TQ4 INT8 prefill -- IMMA tensor core Q@K^T path.

Legacy: Q@K^T-only fused kernel (superseded -- see Key Lesson #7).

Attributes:

Name Type Description
triton_flash_attention Tensor

Vanilla FA forward with online softmax.

triton_flash_attention_tq4 Tensor

Fused TQ4 FA with compressed K tiles.

triton_flash_attention_tq4_kv Tensor

Fused TQ4 FA with compressed K+V tiles.

fused_paged_tq4_decode Tensor

Fused paged TQ4 decode with in-tile decompress.

fused_paged_tq4_int8_prefill Tensor

Fused paged TQ4 INT8 prefill with IMMA.

triton_fa_forward tuple[Tensor, None]

HF AttentionInterface-compatible wrapper.

register_triton_fa None

Register the triton_fa backend globally.

install_triton_fa None

Register and activate vanilla FA on a model.

install_fused_tq4_kv None

Activate fused TQ4 K+V with cache side-channel.

uninstall_fused_tq4_kv None

Remove fused attention and restore SDPA.

tq4_compress None

Fused TQ4 compress (norm+rotate+quantize+pack).

tq4_decompress None

Fused TQ4 decompress (unpack+gather+scale).

fused_qk_scores Tensor

Legacy Q@K^T-only kernel (kept for reference).

Examples:

Direct kernel usage:

from turboquant_vllm.triton import triton_flash_attention

out = triton_flash_attention(q, k, v)

HuggingFace integration:

from turboquant_vllm.triton import install_triton_fa

install_triton_fa(model)
output = model.generate(...)
See Also

turboquant_vllm.kv_cache: CompressedDynamicCache storage layer.

Functions

install_fused_tq4_kv

install_fused_tq4_kv(model: Module, cache: CompressedDynamicCache) -> None

Activate fused TQ4 K+V attention on model with cache side-channel.

Registers the triton_fa_tq4_kv backend, stashes cache on each attention layer as module._tq4_cache, sets the model's _attn_implementation, and enables fused_mode on the cache to skip wasted decompression (P5b optimization).
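The side-channel pattern described above can be mirrored without the library. The sketch below is only an illustration of the control flow: `StubCache`, `StubAttention`, `StubMLP`, `stash_cache`, and `pick_path` are hypothetical stand-ins (plain Python objects, not `torch.nn.Module` or `CompressedDynamicCache`). It shows how modules carrying `layer_idx` receive the cache reference and how the forward function would choose between the fused path and the vanilla fallback.

```python
class StubCache:
    """Hypothetical stand-in for CompressedDynamicCache (only fused_mode is modeled)."""

    def __init__(self):
        self.fused_mode = False


class StubAttention:
    """Hypothetical attention layer: identified by its layer_idx attribute."""

    def __init__(self, layer_idx):
        self.layer_idx = layer_idx


class StubMLP:
    """Hypothetical non-attention layer: has no layer_idx, so it is skipped."""


def stash_cache(modules, cache):
    # Mirror of the install step: enable fused mode, then attach the cache
    # to every module that exposes layer_idx.
    cache.fused_mode = True
    for module in modules:
        if hasattr(module, "layer_idx"):
            object.__setattr__(module, "_tq4_cache", cache)


def pick_path(module):
    # Mirror of the dispatch in the forward function: without a stashed
    # cache or a layer index, fall back to the vanilla kernel.
    cache = getattr(module, "_tq4_cache", None)
    layer_idx = getattr(module, "layer_idx", None)
    if cache is None or layer_idx is None:
        return "vanilla"
    return "fused"


modules = [StubAttention(0), StubMLP(), StubAttention(1)]
cache = StubCache()
stash_cache(modules, cache)
paths = [pick_path(m) for m in modules]
print(paths)
```

Because the cache reference lives on the module rather than in the call signature, the fused forward function can keep the same signature as the other attention backends.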

Parameters:

Name Type Description Default
model Module

HuggingFace model with attention layers that have layer_idx.

required
cache CompressedDynamicCache

CompressedDynamicCache instance that stores compressed K/V.

required

Raises:

Type Description
AttributeError

If model has no config attribute.

Source code in src/turboquant_vllm/triton/attention_interface.py
def install_fused_tq4_kv(model: torch.nn.Module, cache: CompressedDynamicCache) -> None:
    """Activate fused TQ4 K+V attention on *model* with cache side-channel.

    Registers the ``triton_fa_tq4_kv`` backend, stashes *cache* on each
    attention layer as ``module._tq4_cache``, sets the model's
    ``_attn_implementation``, and enables ``fused_mode`` on the cache
    to skip wasted decompression (P5b optimization).

    Args:
        model: HuggingFace model with attention layers that have ``layer_idx``.
        cache: CompressedDynamicCache instance that stores compressed K/V.

    Raises:
        AttributeError: If *model* has no ``config`` attribute.
    """
    ALL_ATTENTION_FUNCTIONS.register("triton_fa_tq4_kv", triton_fa_tq4_kv_forward)

    config = getattr(model, "config", None)
    if config is None:
        msg = "Model has no config attribute"
        raise AttributeError(msg)
    config._attn_implementation = "triton_fa_tq4_kv"

    # Enable fused mode: skip decompression in cache.update()
    cache.fused_mode = True

    # Stash cache reference on each attention layer
    for module in model.modules():
        if hasattr(module, "layer_idx"):
            object.__setattr__(module, "_tq4_cache", cache)

install_triton_fa

install_triton_fa(model: Module) -> None

Register the backend and activate it on model.

Changes model.config._attn_implementation to "triton_fa". The model resolves the attention function at forward time, so this takes effect on the next forward call.

Parameters:

Name Type Description Default
model Module

A HuggingFace model with a config attribute.

required

Raises:

Type Description
AttributeError

If model has no config attribute.

Source code in src/turboquant_vllm/triton/attention_interface.py
def install_triton_fa(model: torch.nn.Module) -> None:
    """Register the backend and activate it on *model*.

    Changes ``model.config._attn_implementation`` to ``"triton_fa"``.
    The model resolves the attention function at forward time, so this
    takes effect on the next forward call.

    Args:
        model: A HuggingFace model with a ``config`` attribute.

    Raises:
        AttributeError: If *model* has no ``config`` attribute.
    """
    register_triton_fa()
    config = getattr(model, "config", None)
    if config is None:
        msg = "Model has no config attribute"
        raise AttributeError(msg)
    config._attn_implementation = "triton_fa"

register_triton_fa

register_triton_fa() -> None

Register triton_fa as a global attention backend in HuggingFace.

Safe to call multiple times -- overwrites the previous registration.

Source code in src/turboquant_vllm/triton/attention_interface.py
def register_triton_fa() -> None:
    """Register ``triton_fa`` as a global attention backend in HuggingFace.

    Safe to call multiple times -- overwrites the previous registration.
    """
    ALL_ATTENTION_FUNCTIONS.register("triton_fa", triton_fa_forward)

triton_fa_forward

triton_fa_forward(
    module: Module,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attention_mask: Optional[Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    **kwargs: object,
) -> tuple[Tensor, None]

HF-compatible attention forward using Triton Flash Attention.

Signature matches transformers.integrations.sdpa_attention.sdpa_attention_forward. Handles GQA natively (no KV repeat expansion needed).
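As a rough illustration of what handling GQA without KV expansion implies, the sketch below maps each query head to the KV head it shares. The kernel's actual indexing is not shown on this page. The contiguous-group mapping (`q_head // group_size`) is the common convention and is an assumption here, and `kv_head_for` is a hypothetical helper name.

```python
def kv_head_for(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Return the KV head shared by q_head (assumed contiguous grouping)."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("Q heads must be divisible by KV heads")
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


# 8 query heads sharing 2 KV heads: each KV head serves 4 query heads.
mapping = [kv_head_for(h, num_q_heads=8, num_kv_heads=2) for h in range(8)]
print(mapping)
```

With an index mapping like this, K and V can be read once per KV head instead of being materialized once per query head.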

Parameters:

Name Type Description Default
module Module

The attention layer module. Used to read is_causal attribute.

required
query Tensor

[batch, num_q_heads, seq_q, head_dim].

required
key Tensor

[batch, num_kv_heads, seq_kv, head_dim].

required
value Tensor

[batch, num_kv_heads, seq_kv, head_dim].

required
attention_mask Optional[Tensor]

Optional additive mask [batch, 1|heads, seq_q, seq_kv].

required
dropout float

Dropout rate (must be 0 -- inference only).

0.0
scaling Optional[float]

Softmax scale. Defaults to 1 / sqrt(head_dim).

None

Other Parameters:

Name Type Description
is_causal bool | None

Override causal mode. If None, inferred from query.shape[2] and module.is_causal.

**kwargs object

Additional model-specific arguments (ignored).
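The causal-mode inference described for `is_causal` can be restated as a small pure function. This is a sketch that follows the logic in the source listing. `infer_is_causal` and its parameter names are hypothetical, and tensors are replaced by a plain sequence length and a mask flag.

```python
from typing import Optional


def infer_is_causal(
    seq_q: int,
    has_mask: bool,
    module_is_causal: bool = True,
    override: Optional[bool] = None,
) -> bool:
    """Decide causal mode the way the wrapper does (sketch)."""
    if override is not None:
        # An explicit is_causal keyword wins over inference.
        return bool(override)
    # Inferred: multi-token query, no explicit mask, module marked causal.
    return seq_q > 1 and not has_mask and module_is_causal


prefill = infer_is_causal(seq_q=16, has_mask=False)       # True
decode = infer_is_causal(seq_q=1, has_mask=False)         # False
masked = infer_is_causal(seq_q=16, has_mask=True)         # False
forced = infer_is_causal(seq_q=1, has_mask=False, override=True)  # True
print(prefill, decode, masked, forced)
```

Note that even when the override yields `True` for a single query token, `triton_flash_attention` resets `is_causal` to `False` when `seq_q == 1`.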

Returns:

Type Description
tuple[Tensor, None]

(output, None) where output is [batch, seq_q, num_q_heads, head_dim] (transposed to match HF convention).

Source code in src/turboquant_vllm/triton/attention_interface.py
def triton_fa_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    **kwargs: object,
) -> tuple[torch.Tensor, None]:
    """HF-compatible attention forward using Triton Flash Attention.

    Signature matches ``transformers.integrations.sdpa_attention.sdpa_attention_forward``.
    Handles GQA natively (no KV repeat expansion needed).

    Args:
        module: The attention layer module. Used to read ``is_causal`` attribute.
        query: ``[batch, num_q_heads, seq_q, head_dim]``.
        key: ``[batch, num_kv_heads, seq_kv, head_dim]``.
        value: ``[batch, num_kv_heads, seq_kv, head_dim]``.
        attention_mask: Optional additive mask ``[batch, 1|heads, seq_q, seq_kv]``.
        dropout: Dropout rate (must be 0 -- inference only).
        scaling: Softmax scale. Defaults to ``1 / sqrt(head_dim)``.

    Other Parameters:
        is_causal (bool | None): Override causal mode. If ``None``, inferred
            from ``query.shape[2]`` and ``module.is_causal``.
        **kwargs: Additional model-specific arguments (ignored).

    Returns:
        ``(output, None)`` where output is ``[batch, seq_q, num_q_heads, head_dim]``
        (transposed to match HF convention).
    """
    # Determine causal mode (same logic as HF SDPA backend)
    is_causal_raw = kwargs.pop("is_causal", None)
    is_causal_flag: bool
    if is_causal_raw is None:
        is_causal_flag = bool(
            query.shape[2] > 1
            and attention_mask is None
            and getattr(module, "is_causal", True)
        )
    else:
        is_causal_flag = bool(is_causal_raw)

    out = triton_flash_attention(
        query,
        key,
        value,
        sm_scale=scaling,
        is_causal=is_causal_flag,
        attention_mask=attention_mask if not is_causal_flag else None,
    )

    # Transpose to [batch, seq, heads, head_dim] per HF convention
    out = out.transpose(1, 2).contiguous()
    return out, None

triton_fa_tq4_kv_forward

triton_fa_tq4_kv_forward(
    module: Module,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attention_mask: Optional[Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    **kwargs: object,
) -> tuple[Tensor, None]

Fused TQ4 K+V attention via cache side-channel.

Reads compressed K/V from the CompressedDynamicCache stashed on module._tq4_cache (ignoring the decompressed key/value args). Falls back to vanilla Triton FA if no cache reference is found.

Parameters:

Name Type Description Default
module Module

Attention layer with layer_idx and _tq4_cache attrs.

required
query Tensor

[batch, H_Q, seq_q, head_dim] (RoPE already applied).

required
key Tensor

[batch, H_KV, seq_kv, head_dim] — ignored when fused.

required
value Tensor

[batch, H_KV, seq_kv, head_dim] — ignored when fused.

required
attention_mask Optional[Tensor]

Optional additive mask.

required
dropout float

Must be 0 (inference only).

0.0
scaling Optional[float]

Softmax scale.

None

Other Parameters:

Name Type Description
is_causal bool | None

Override causal mode.

**kwargs object

Additional model-specific arguments (ignored).

Returns:

Type Description
tuple[Tensor, None]

(output, None) with output [batch, seq_q, H_Q, head_dim].

Source code in src/turboquant_vllm/triton/attention_interface.py
def triton_fa_tq4_kv_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    **kwargs: object,
) -> tuple[torch.Tensor, None]:
    """Fused TQ4 K+V attention via cache side-channel.

    Reads compressed K/V from the ``CompressedDynamicCache`` stashed on
    ``module._tq4_cache`` (ignoring the decompressed *key*/*value* args).
    Falls back to vanilla Triton FA if no cache reference is found.

    Args:
        module: Attention layer with ``layer_idx`` and ``_tq4_cache`` attrs.
        query: ``[batch, H_Q, seq_q, head_dim]`` (RoPE already applied).
        key: ``[batch, H_KV, seq_kv, head_dim]`` — ignored when fused.
        value: ``[batch, H_KV, seq_kv, head_dim]`` — ignored when fused.
        attention_mask: Optional additive mask.
        dropout: Must be 0 (inference only).
        scaling: Softmax scale.

    Other Parameters:
        is_causal (bool | None): Override causal mode.
        **kwargs: Additional model-specific arguments (ignored).

    Returns:
        ``(output, None)`` with output ``[batch, seq_q, H_Q, head_dim]``.
    """
    cache: CompressedDynamicCache | None = getattr(module, "_tq4_cache", None)
    layer_idx: int | None = getattr(module, "layer_idx", None)

    # Fallback: no cache stash → use vanilla Triton FA with decompressed K/V
    if cache is None or layer_idx is None:
        return triton_fa_forward(
            module, query, key, value, attention_mask, dropout, scaling, **kwargs
        )

    # Read compressed K/V from cache side-channel
    k_packed, k_norms, v_packed, v_norms = cache.get_compressed(layer_idx)
    rotation = cache.rotation.to(device=query.device)
    centroids = cache.centroids.to(device=query.device)

    # Determine causal mode
    is_causal_raw = kwargs.pop("is_causal", None)
    is_causal_flag: bool
    if is_causal_raw is None:
        is_causal_flag = bool(
            query.shape[2] > 1
            and attention_mask is None
            and getattr(module, "is_causal", True)
        )
    else:
        is_causal_flag = bool(is_causal_raw)

    out = triton_flash_attention_tq4_kv(
        query,
        k_packed,
        k_norms,
        v_packed,
        v_norms,
        centroids,
        rotation,
        sm_scale=scaling,
        is_causal=is_causal_flag,
    )

    out = out.transpose(1, 2).contiguous()
    return out, None

uninstall_fused_tq4_kv

uninstall_fused_tq4_kv(model: Module) -> None

Remove fused TQ4 attention and restore SDPA.

Removes _tq4_cache from attention layers, disables fused_mode on the cache, and resets _attn_implementation to "sdpa".

Parameters:

Name Type Description Default
model Module

Model previously configured with install_fused_tq4_kv.

required
Source code in src/turboquant_vllm/triton/attention_interface.py
def uninstall_fused_tq4_kv(model: torch.nn.Module) -> None:
    """Remove fused TQ4 attention and restore SDPA.

    Removes ``_tq4_cache`` from attention layers, disables ``fused_mode``
    on the cache, and resets ``_attn_implementation`` to ``"sdpa"``.

    Args:
        model: Model previously configured with ``install_fused_tq4_kv``.
    """
    config = getattr(model, "config", None)
    if config is not None:
        config._attn_implementation = "sdpa"

    for module in model.modules():
        if not hasattr(module, "_tq4_cache"):
            continue
        cache = getattr(module, "_tq4_cache", None)
        # Turn fused mode off so cache.update() decompresses again
        if cache is not None and hasattr(cache, "fused_mode"):
            cache.fused_mode = False  # type: ignore[union-attr]
        delattr(module, "_tq4_cache")

triton_flash_attention

triton_flash_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    sm_scale: float | None = None,
    is_causal: bool = False,
    attention_mask: Tensor | None = None,
) -> Tensor

Compute scaled dot-product attention using Triton Flash Attention.
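The module summary lists this kernel as a forward pass with online softmax. As a rough illustration of that technique (not the kernel's code), the pure-Python sketch below processes the scores of a single query in blocks while keeping a running maximum, a running normalizer, and a rescaled accumulator. The function names and the sample data are hypothetical.

```python
import math


def attend_online(scores, values, block=2):
    """Single-query attention computed block by block with online softmax."""
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        new_max = max(running_max, max(s_blk))
        # Rescale what was accumulated under the old maximum
        # (exp(-inf) == 0.0 on the first block).
        rescale = math.exp(running_max - new_max)
        probs = [math.exp(s - new_max) for s in s_blk]
        running_sum = running_sum * rescale + sum(probs)
        acc = [
            a * rescale + sum(p * v[d] for p, v in zip(probs, v_blk))
            for d, a in enumerate(acc)
        ]
        running_max = new_max
    return [a / running_sum for a in acc]


def attend_reference(scores, values):
    """Same result computed with a full softmax over all scores at once."""
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


scores = [0.5, 2.0, -1.0, 3.0, 0.0]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
online = attend_online(scores, values)
reference = attend_reference(scores, values)
print(online, reference)
```

The blockwise version never holds the full score row, which is what lets a Flash Attention kernel stream K/V tiles through on-chip memory.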

Non-power-of-two head dimensions (e.g., 96) are supported via padded tile loads and boundary masking inside the kernel.
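To make the padding concrete, the sketch below computes the padded tile width for a head dimension of 96 and the boundary mask that would disable the extra lanes. `next_pow2` is a hypothetical stand-in for the private `_next_pow2` helper used in the source listing, and the Python list only imitates what a masked tile load does.

```python
def next_pow2(n: int) -> int:
    """Smallest power of two >= n (hypothetical stand-in for _next_pow2)."""
    if n < 1:
        raise ValueError("n must be positive")
    return 1 << (n - 1).bit_length()


head_dim = 96
head_dim_pad = next_pow2(head_dim)

# Boundary mask: lanes at or beyond head_dim are padding and must be ignored.
lane_mask = [d < head_dim for d in range(head_dim_pad)]
padded_lanes = lane_mask.count(False)
wasted_fraction = padded_lanes / head_dim_pad
print(head_dim_pad, padded_lanes, wasted_fraction)
```

For power-of-two head dimensions such as 64 or 128 the padded width equals the head dimension, so the mask disables no lanes.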

Parameters:

Name Type Description Default
q Tensor

Query tensor [batch, num_q_heads, seq_q, head_dim].

required
k Tensor

Key tensor [batch, num_kv_heads, seq_kv, head_dim].

required
v Tensor

Value tensor [batch, num_kv_heads, seq_kv, head_dim].

required
sm_scale float | None

Softmax scale factor. Defaults to 1 / sqrt(head_dim).

None
is_causal bool

Apply causal masking. Only valid when seq_q >= seq_kv (prefill). For decode (seq_q == 1), forced to False.

False
attention_mask Tensor | None

Optional additive mask [batch, 1|heads, seq_q, seq_kv]. Values of 0 mean attend, large negative values mean block. Mutually exclusive with is_causal in practice (HF sets is_causal only when attention_mask is None).

None
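The mask shape allows size-1 dimensions that are broadcast across batch or heads, and the source listing zeroes the corresponding strides before launching the kernel. The pure-Python sketch below shows why: with the strides PyTorch reports for a contiguous `[1, 1, 4, 6]` tensor, indexing batch 3 or head 5 would run past the buffer unless those strides are treated as 0. The helper names are hypothetical.

```python
def effective_mask_strides(shape, strides):
    """Zero the batch/head strides of size-1 dims so they broadcast."""
    stride_z, stride_h, stride_m, stride_n = strides
    if shape[0] == 1:
        stride_z = 0
    if shape[1] == 1:
        stride_h = 0
    return (stride_z, stride_h, stride_m, stride_n)


def mask_offset(b, h, m, n, strides):
    """Linear element offset of mask[b, h, m, n] for the given strides."""
    stride_z, stride_h, stride_m, stride_n = strides
    return b * stride_z + h * stride_h + m * stride_m + n * stride_n


shape = (1, 1, 4, 6)          # one mask shared by every batch entry and head
raw_strides = (24, 24, 6, 1)  # contiguous strides for that shape
num_elements = 4 * 6

fixed = effective_mask_strides(shape, raw_strides)
good = mask_offset(3, 5, 2, 1, fixed)        # stays inside the 24 elements
bad = mask_offset(3, 5, 2, 1, raw_strides)   # would read out of bounds
print(fixed, good, bad, num_elements)
```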

Returns:

Type Description
Tensor

Attention output [batch, num_q_heads, seq_q, head_dim].

Source code in src/turboquant_vllm/triton/flash_attention.py
def triton_flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    sm_scale: float | None = None,
    is_causal: bool = False,
    attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute scaled dot-product attention using Triton Flash Attention.

    Non-power-of-two head dimensions (e.g., 96) are supported via padded
    tile loads and boundary masking inside the kernel.

    Args:
        q: Query tensor ``[batch, num_q_heads, seq_q, head_dim]``.
        k: Key tensor ``[batch, num_kv_heads, seq_kv, head_dim]``.
        v: Value tensor ``[batch, num_kv_heads, seq_kv, head_dim]``.
        sm_scale: Softmax scale factor. Defaults to ``1 / sqrt(head_dim)``.
        is_causal: Apply causal masking. Only valid when ``seq_q >= seq_kv``
            (prefill). For decode (``seq_q == 1``), forced to ``False``.
        attention_mask: Optional additive mask
            ``[batch, 1|heads, seq_q, seq_kv]``. Values of 0 mean attend,
            large negative values mean block. Mutually exclusive with
            ``is_causal`` in practice (HF sets ``is_causal`` only when
            ``attention_mask is None``).

    Returns:
        Attention output ``[batch, num_q_heads, seq_q, head_dim]``.
    """
    B, H_Q, N_Q, D = q.shape
    _, H_KV, N_KV, _ = k.shape

    assert q.dtype == k.dtype == v.dtype, "Q, K, V must have the same dtype"
    assert q.dtype in (
        torch.float16,
        torch.bfloat16,
    ), f"Triton FA requires fp16 or bf16, got {q.dtype}"
    assert H_Q % H_KV == 0, f"Q heads ({H_Q}) must be divisible by KV heads ({H_KV})"
    assert k.shape[2] == v.shape[2], "K and V must have the same sequence length"
    assert q.shape[3] == k.shape[3] == v.shape[3], "Head dimensions must match"

    # Single query token never needs causal masking
    if is_causal and N_Q == 1:
        is_causal = False

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    out = torch.empty_like(q)

    # Handle optional mask
    has_mask = attention_mask is not None
    if has_mask:
        assert attention_mask is not None  # for type narrowing
        mask = attention_mask
        stride_mz, stride_mh, stride_mm, stride_mn = mask.stride()
        # Handle broadcast: size-1 dims have nonzero stride in PyTorch
        # but must be treated as stride=0 for correct broadcast behavior.
        if mask.shape[0] == 1:
            stride_mz = 0
        if mask.shape[1] == 1:
            stride_mh = 0
    else:
        mask = q  # dummy pointer, never dereferenced
        stride_mz = stride_mh = stride_mm = stride_mn = 0

    # Grid: one CTA per (Q-block, batch*q_head)
    def grid(META: dict) -> tuple[int, int]:
        """Compute kernel launch grid from autotuned block size.

        Returns:
            ``(num_q_blocks, batch * num_q_heads)`` grid dimensions.
        """
        return (triton.cdiv(N_Q, META["BLOCK_M"]), B * H_Q)

    _fwd_kernel[grid](
        q,
        k,
        v,
        out,
        mask,
        sm_scale,
        *q.stride(),
        *k.stride(),
        *v.stride(),
        *out.stride(),
        stride_mz,
        stride_mh,
        stride_mm,
        stride_mn,
        H_Q,
        H_KV,
        N_Q,
        N_KV,
        HEAD_DIM=D,
        HEAD_DIM_PAD=_next_pow2(D),
        IS_CAUSAL=is_causal,
        HAS_MASK=has_mask,
    )

    return out

triton_flash_attention_tq4

triton_flash_attention_tq4(
    q: Tensor,
    k_packed: Tensor,
    k_norms: Tensor,
    centroids: Tensor,
    rotation: Tensor,
    v: Tensor,
    sm_scale: float | None = None,
    is_causal: bool = False,
) -> Tensor

Fused TQ4 Flash Attention with compressed K and standard V.

Pre-rotates Q by rotation^T, then launches the fused kernel that decompresses nibble-packed K indices inline via centroid gather. Non-power-of-two head dimensions (e.g., 96) are handled via padded tile loads and boundary masking inside the kernel. Non-pow2 dims incur ~5-15 % throughput penalty due to wasted lanes in padded tiles (e.g., head_dim=96 pads to 128, wasting 25 % of memory bandwidth on K/V loads).
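The inline decompression can be pictured as three steps per key row: unpack two 4-bit indices from each byte, gather the matching codebook entries, and scale by the stored norm. The sketch below is a pure-Python illustration under stated assumptions. The nibble order (low nibble first) is assumed, not confirmed by this page. The codebook is a made-up uniform grid rather than real Lloyd-Max centroids, and the function names are hypothetical.

```python
def unpack_nibbles(packed):
    """Split each byte into two 4-bit indices (low nibble first: assumed order)."""
    indices = []
    for byte in packed:
        indices.append(byte & 0x0F)
        indices.append((byte >> 4) & 0x0F)
    return indices


def decompress_row(packed, norm, centroids):
    """Gather centroids by index and scale by the row norm (rotated space)."""
    return [norm * centroids[i] for i in unpack_nibbles(packed)]


# Illustrative 16-entry codebook; real Lloyd-Max centroids would differ.
centroids = [(i - 7.5) / 8 for i in range(16)]

packed_row = bytes([0x10, 0xF8])   # head_dim = 4 packed into head_dim // 2 bytes
indices = unpack_nibbles(packed_row)
row = decompress_row(packed_row, norm=2.0, centroids=centroids)
print(indices, row)
```

The reconstructed row lives in the rotated coordinate space, which is why the wrapper rotates Q before the kernel runs rather than rotating every key back.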

Parameters:

Name Type Description Default
q Tensor

Query [batch, H_Q, seq_q, head_dim] fp16/bf16.

required
k_packed Tensor

Nibble-packed key indices [batch, H_KV, seq_kv, head_dim//2] uint8.

required
k_norms Tensor

Key norms [batch, H_KV, seq_kv] or [..., 1] fp32.

required
centroids Tensor

Lloyd-Max codebook [16] fp32 (for 4-bit).

required
rotation Tensor

Orthogonal rotation matrix [head_dim, head_dim] fp32.

required
v Tensor

Values [batch, H_KV, seq_kv, head_dim] fp16/bf16.

required
sm_scale float | None

Softmax scale. Defaults to 1 / sqrt(head_dim).

None
is_causal bool

Apply causal masking.

False

Returns:

Type Description
Tensor

Attention output [batch, H_Q, seq_q, head_dim].

Source code in src/turboquant_vllm/triton/flash_attention_tq4.py
def triton_flash_attention_tq4(
    q: torch.Tensor,
    k_packed: torch.Tensor,
    k_norms: torch.Tensor,
    centroids: torch.Tensor,
    rotation: torch.Tensor,
    v: torch.Tensor,
    sm_scale: float | None = None,
    is_causal: bool = False,
) -> torch.Tensor:
    """Fused TQ4 Flash Attention with compressed K and standard V.

    Pre-rotates Q by ``rotation^T``, then launches the fused kernel that
    decompresses nibble-packed K indices inline via centroid gather.
    Non-power-of-two head dimensions (e.g., 96) are handled via padded
    tile loads and boundary masking inside the kernel.  Non-pow2 dims incur
    ~5-15 % throughput penalty due to wasted lanes in padded tiles (e.g.,
    head_dim=96 pads to 128, wasting 25 % of memory bandwidth on K/V loads).

    Args:
        q: Query ``[batch, H_Q, seq_q, head_dim]`` fp16/bf16.
        k_packed: Nibble-packed key indices ``[batch, H_KV, seq_kv, head_dim//2]`` uint8.
        k_norms: Key norms ``[batch, H_KV, seq_kv]`` or ``[..., 1]`` fp32.
        centroids: Lloyd-Max codebook ``[16]`` fp32 (for 4-bit).
        rotation: Orthogonal rotation matrix ``[head_dim, head_dim]`` fp32.
        v: Values ``[batch, H_KV, seq_kv, head_dim]`` fp16/bf16.
        sm_scale: Softmax scale. Defaults to ``1 / sqrt(head_dim)``.
        is_causal: Apply causal masking.

    Returns:
        Attention output ``[batch, H_Q, seq_q, head_dim]``.
    """
    B, H_Q, N_Q, D = q.shape
    _, H_KV, N_KV, HALF_D = k_packed.shape

    assert D % 2 == 0, f"HEAD_DIM must be even, got {D}"
    assert HALF_D == D // 2, f"Packed dim {HALF_D} != head_dim//2 ({D // 2})"
    assert H_Q % H_KV == 0, f"Q heads ({H_Q}) must be divisible by KV heads ({H_KV})"
    assert k_packed.dtype == torch.uint8, "k_packed must be uint8"
    assert k_norms.dtype == torch.float32, "k_norms must be float32"
    assert centroids.dtype == torch.float32, "centroids must be float32"

    # Squeeze trailing 1 from norms if present
    if k_norms.dim() == 4 and k_norms.shape[-1] == 1:
        k_norms = k_norms.squeeze(-1)

    if is_causal and N_Q == 1:
        is_causal = False

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # Pre-rotate Q: q_rot = q @ Pi^T
    q_rot = torch.matmul(q.float(), rotation.T).to(q.dtype)

    out = torch.empty_like(q)

    def grid(META: dict) -> tuple[int, int]:
        """Compute launch grid from autotuned block size.

        Returns:
            ``(num_q_blocks, batch * H_Q)`` grid dimensions.
        """
        return (triton.cdiv(N_Q, META["BLOCK_M"]), B * H_Q)

    HEAD_DIM_PAD = _next_pow2(D)
    HALF_D_PAD = _next_pow2(D // 2)
    assert HALF_D_PAD * 2 == HEAD_DIM_PAD, (
        f"Padding invariant violated: 2*HALF_D_PAD ({2 * HALF_D_PAD}) "
        f"!= HEAD_DIM_PAD ({HEAD_DIM_PAD}) — tl.join reshape requires this"
    )

    _fwd_tq4_kernel[grid](
        q_rot,
        k_packed,
        k_norms,
        centroids,
        v,
        out,
        sm_scale,
        *q_rot.stride(),
        *k_packed.stride(),
        *k_norms.stride(),
        *v.stride(),
        *out.stride(),
        H_Q,
        H_KV,
        N_Q,
        N_KV,
        HEAD_DIM=D,
        HEAD_DIM_PAD=HEAD_DIM_PAD,
        HALF_D_PAD=HALF_D_PAD,
        IS_CAUSAL=is_causal,
    )

    return out

triton_flash_attention_tq4_kv

triton_flash_attention_tq4_kv(
    q: Tensor,
    k_packed: Tensor,
    k_norms: Tensor,
    v_packed: Tensor,
    v_norms: Tensor,
    centroids: Tensor,
    rotation: Tensor,
    sm_scale: float | None = None,
    is_causal: bool = False,
) -> Tensor

Fused TQ4 Flash Attention with both K and V compressed.

Pre-rotates Q by rotation^T, launches the kernel that decompresses both K and V inline, then post-rotates the output by rotation to return to the original coordinate space. Non-power-of-two head dimensions (e.g., 96) are handled via padded tile loads and boundary masking inside the kernel. Non-pow2 dims incur ~5-15 % throughput penalty due to wasted lanes in padded tiles (e.g., head_dim=96 pads to 128, wasting 25 % of memory bandwidth on K/V loads).
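The pre- and post-rotation steps rely on the rotation being orthogonal. The sketch below checks the algebra in two dimensions with plain Python lists. It assumes that the stored K and V are in the rotated space (each row multiplied by `rotation^T`), and it leaves out quantization so that the comparison is exact up to floating-point error. All names and sample values are illustrative.

```python
import math


def row_times(x, m):
    """Row vector x times matrix m, i.e. x @ m."""
    return [sum(x[i] * m[i][j] for i in range(len(x))) for j in range(len(m[0]))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(xs):
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_sum(weights, rows):
    return [sum(w * r[d] for w, r in zip(weights, rows)) for d in range(len(rows[0]))]


theta = 0.7
rotation = [[math.cos(theta), -math.sin(theta)],
            [math.sin(theta), math.cos(theta)]]
rotation_t = transpose(rotation)

q = [1.0, 2.0]
keys = [[0.5, -1.0], [2.0, 0.25]]
vals = [[1.0, 0.0], [0.0, 3.0]]

# Rotated-space operands: Q is rotated on the fly, K/V are assumed stored rotated.
q_rot = row_times(q, rotation_t)
keys_rot = [row_times(k, rotation_t) for k in keys]
vals_rot = [row_times(v, rotation_t) for v in vals]

scores = [dot(q, k) for k in keys]
scores_rot = [dot(q_rot, k) for k in keys_rot]   # equal to scores

out_rot = weighted_sum(softmax(scores_rot), vals_rot)
out = row_times(out_rot, rotation)               # post-rotate back
reference = weighted_sum(softmax(scores), vals)
print(out, reference)
```

Since the scores are unchanged by the rotation, only the query (before the kernel) and the output (after it) need a matrix multiply, and both have cost proportional to the number of query tokens rather than the cache length.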

Parameters:

Name Type Description Default
q Tensor

Query [batch, H_Q, seq_q, head_dim] fp16/bf16.

required
k_packed Tensor

Nibble-packed key indices [batch, H_KV, seq_kv, D//2] uint8.

required
k_norms Tensor

Key norms [batch, H_KV, seq_kv] or [..., 1] fp32.

required
v_packed Tensor

Nibble-packed value indices [batch, H_KV, seq_kv, D//2] uint8.

required
v_norms Tensor

Value norms [batch, H_KV, seq_kv] or [..., 1] fp32.

required
centroids Tensor

Shared Lloyd-Max codebook [16] fp32.

required
rotation Tensor

Shared orthogonal rotation [head_dim, head_dim] fp32.

required
sm_scale float | None

Softmax scale. Defaults to 1 / sqrt(head_dim).

None
is_causal bool

Apply causal masking.

False

Returns:

Type Description
Tensor

Attention output [batch, H_Q, seq_q, head_dim] in original space.

Source code in src/turboquant_vllm/triton/flash_attention_tq4_kv.py
def triton_flash_attention_tq4_kv(
    q: torch.Tensor,
    k_packed: torch.Tensor,
    k_norms: torch.Tensor,
    v_packed: torch.Tensor,
    v_norms: torch.Tensor,
    centroids: torch.Tensor,
    rotation: torch.Tensor,
    sm_scale: float | None = None,
    is_causal: bool = False,
) -> torch.Tensor:
    """Fused TQ4 Flash Attention with both K and V compressed.

    Pre-rotates Q by ``rotation^T``, launches the kernel that decompresses
    both K and V inline, then post-rotates the output by ``rotation`` to
    return to the original coordinate space. Non-power-of-two head
    dimensions (e.g., 96) are handled via padded tile loads and boundary
    masking inside the kernel.  Non-pow2 dims incur ~5-15 % throughput
    penalty due to wasted lanes in padded tiles (e.g., head_dim=96 pads
    to 128, wasting 25 % of memory bandwidth on K/V loads).

    Args:
        q: Query ``[batch, H_Q, seq_q, head_dim]`` fp16/bf16.
        k_packed: Nibble-packed key indices ``[batch, H_KV, seq_kv, D//2]`` uint8.
        k_norms: Key norms ``[batch, H_KV, seq_kv]`` or ``[..., 1]`` fp32.
        v_packed: Nibble-packed value indices ``[batch, H_KV, seq_kv, D//2]`` uint8.
        v_norms: Value norms ``[batch, H_KV, seq_kv]`` or ``[..., 1]`` fp32.
        centroids: Shared Lloyd-Max codebook ``[16]`` fp32.
        rotation: Shared orthogonal rotation ``[head_dim, head_dim]`` fp32.
        sm_scale: Softmax scale. Defaults to ``1 / sqrt(head_dim)``.
        is_causal: Apply causal masking.

    Returns:
        Attention output ``[batch, H_Q, seq_q, head_dim]`` in original space.
    """
    B, H_Q, N_Q, D = q.shape
    _, H_KV, N_KV, HALF_D = k_packed.shape

    assert D % 2 == 0, f"HEAD_DIM must be even, got {D}"
    assert HALF_D == D // 2
    assert H_Q % H_KV == 0
    assert k_packed.dtype == torch.uint8
    assert v_packed.dtype == torch.uint8
    assert k_norms.dtype == torch.float32
    assert v_norms.dtype == torch.float32

    # Squeeze trailing 1 from norms
    if k_norms.dim() == 4 and k_norms.shape[-1] == 1:
        k_norms = k_norms.squeeze(-1)
    if v_norms.dim() == 4 and v_norms.shape[-1] == 1:
        v_norms = v_norms.squeeze(-1)

    if is_causal and N_Q == 1:
        is_causal = False

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # Pre-rotate Q
    q_rot = torch.matmul(q.float(), rotation.T).to(q.dtype)

    out_rot = torch.empty_like(q)

    def grid(META: dict) -> tuple[int, int]:
        """Compute launch grid.

        Returns:
            ``(num_q_blocks, batch * H_Q)`` grid dimensions.
        """
        return (triton.cdiv(N_Q, META["BLOCK_M"]), B * H_Q)

    HEAD_DIM_PAD = _next_pow2(D)
    HALF_D_PAD = _next_pow2(D // 2)
    assert HALF_D_PAD * 2 == HEAD_DIM_PAD, (
        f"Padding invariant violated: 2*HALF_D_PAD ({2 * HALF_D_PAD}) "
        f"!= HEAD_DIM_PAD ({HEAD_DIM_PAD}) — tl.join reshape requires this"
    )

    _fwd_tq4_kv_kernel[grid](
        q_rot,
        k_packed,
        k_norms,
        v_packed,
        v_norms,
        centroids,
        out_rot,
        sm_scale,
        *q_rot.stride(),
        *k_packed.stride(),
        *k_norms.stride(),
        *v_packed.stride(),
        *v_norms.stride(),
        *out_rot.stride(),
        H_Q,
        H_KV,
        N_Q,
        N_KV,
        HEAD_DIM=D,
        HEAD_DIM_PAD=HEAD_DIM_PAD,
        HALF_D_PAD=HALF_D_PAD,
        IS_CAUSAL=is_causal,
    )

    # Post-rotate: convert from rotated space back to original space
    return torch.matmul(out_rot.float(), rotation).to(q.dtype)

fused_paged_tq4_decode

fused_paged_tq4_decode(
    q: Tensor,
    kv_cache: Tensor,
    block_table: Tensor,
    seq_lens: Tensor,
    centroids: Tensor,
    rotation: Tensor,
    num_kv_heads: int,
    head_dim: int,
    block_size: int,
    sm_scale: float | None = None,
    out: Tensor | None = None,
) -> Tensor

Fused paged TQ4 decode attention.

Pre-rotates Q by rotation^T, launches the fused paged kernel that decompresses TQ4 blocks in-tile from the page table, then post-rotates the output by rotation to return to original space.
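Reading from the page table means translating a token position within a sequence into a physical block and a slot inside that block. The kernel's own indexing is not shown on this page. The sketch below assumes the usual paged-cache convention (logical block `pos // block_size`, slot `pos % block_size`), and `locate_token` is a hypothetical helper.

```python
def locate_token(block_table_row, pos, block_size):
    """Map a token position to (physical_block, slot) via one page-table row."""
    logical_block = pos // block_size
    physical_block = block_table_row[logical_block]
    slot = pos % block_size
    return physical_block, slot


block_size = 16
block_table_row = [7, 2, 9]   # physical blocks backing one sequence, in order
seq_len = 40                  # tokens currently stored for that sequence

blocks_needed = -(-seq_len // block_size)  # ceiling division
locations = [locate_token(block_table_row, p, block_size) for p in (0, 15, 16, 39)]
print(blocks_needed, locations)
```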

Parameters:

Name Type Description Default
q Tensor

Query [num_seqs, H_Q, head_dim] fp16/bf16 (one token per seq).

required
kv_cache Tensor

Packed paged cache [num_blocks, block_size, total_bytes] uint8.

required
block_table Tensor

Page table [num_seqs, max_num_blocks_per_seq] int32.

required
seq_lens Tensor

Sequence lengths [num_seqs] int32.

required
centroids Tensor

TQ4 codebook [16] fp32.

required
rotation Tensor

Orthogonal rotation [head_dim, head_dim] fp32.

required
num_kv_heads int

Number of KV heads.

required
head_dim int

Head dimension (e.g. 128).

required
block_size int

vLLM page size (tokens per block).

required
sm_scale float | None

Softmax scale. Defaults to 1 / sqrt(head_dim).

None
out Tensor | None

Optional pre-allocated output [num_seqs, H_Q, head_dim]. When provided, the final post-rotated result is copied into this buffer and returned. A private scratch buffer is used internally for the kernel's rotated-space output.

None

Returns:

Type Description
Tensor

Attention output [num_seqs, H_Q, head_dim] in original space. When out is provided, returns out with the result written in-place.

Note

INT8 placeholder parameters (Q_scale, QJL_S, QJL_signs, QJL_residual_norms) should be passed as None/zeros when USE_INT8_QK=False (the default for Phase 3a).
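The `total_bytes` axis of `kv_cache` packs K indices, K norms, V indices, and V norms for one token. The sketch below reproduces the offset arithmetic from the wrapper's source listing. The final `total_bytes` value is an inference (it assumes V norms are 4-byte fp32 values like K norms and that nothing follows them), and `tq4_slot_layout` is a hypothetical helper name.

```python
def tq4_slot_layout(num_kv_heads: int, head_dim: int) -> dict:
    """Byte offsets within one token slot of the packed paged cache (sketch)."""
    half_d = head_dim // 2                      # two 4-bit indices per byte
    k_norm_offset = num_kv_heads * half_d       # K indices start at byte 0
    v_idx_offset = k_norm_offset + num_kv_heads * 4   # 4 bytes per fp32 K norm
    v_norm_offset = v_idx_offset + num_kv_heads * half_d
    total_bytes = v_norm_offset + num_kv_heads * 4    # assumption: fp32 V norms last
    return {
        "k_idx_offset": 0,
        "k_norm_offset": k_norm_offset,
        "v_idx_offset": v_idx_offset,
        "v_norm_offset": v_norm_offset,
        "total_bytes": total_bytes,
    }


layout = tq4_slot_layout(num_kv_heads=8, head_dim=128)

# For comparison: uncompressed fp16 K and V for the same token.
fp16_bytes = 2 * 8 * 128 * 2
print(layout, fp16_bytes)
```

Under these assumptions a token slot for 8 KV heads of dimension 128 takes 1088 bytes, compared with 4096 bytes for fp16 K and V.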

Source code in src/turboquant_vllm/triton/fused_paged_tq4_attention.py
def fused_paged_tq4_decode(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    seq_lens: torch.Tensor,
    centroids: torch.Tensor,
    rotation: torch.Tensor,
    num_kv_heads: int,
    head_dim: int,
    block_size: int,
    sm_scale: float | None = None,
    out: torch.Tensor | None = None,
) -> torch.Tensor:
    """Fused paged TQ4 decode attention.

    Pre-rotates Q by ``rotation^T``, launches the fused paged kernel
    that decompresses TQ4 blocks in-tile from the page table, then
    post-rotates the output by ``rotation`` to return to original space.

    Args:
        q: Query ``[num_seqs, H_Q, head_dim]`` fp16/bf16 (one token per seq).
        kv_cache: Packed paged cache ``[num_blocks, block_size, total_bytes]``
            uint8.
        block_table: Page table ``[num_seqs, max_num_blocks_per_seq]`` int32.
        seq_lens: Sequence lengths ``[num_seqs]`` int32.
        centroids: TQ4 codebook ``[16]`` fp32.
        rotation: Orthogonal rotation ``[head_dim, head_dim]`` fp32.
        num_kv_heads: Number of KV heads.
        head_dim: Head dimension (e.g. 128).
        block_size: vLLM page size (tokens per block).
        sm_scale: Softmax scale.  Defaults to ``1 / sqrt(head_dim)``.
        out: Optional pre-allocated output ``[num_seqs, H_Q, head_dim]``.
            When provided, the final post-rotated result is copied into
            this buffer and returned.  A private scratch buffer is used
            internally for the kernel's rotated-space output.

    Returns:
        Attention output ``[num_seqs, H_Q, head_dim]`` in original space.
        When ``out`` is provided, returns ``out`` with the result written
        in-place.

    Note:
        INT8 placeholder parameters (``Q_scale``, ``QJL_S``, ``QJL_signs``,
        ``QJL_residual_norms``) belong to the underlying kernel, not to this
        wrapper.  With ``USE_INT8_QK=False`` (the default for Phase 3a) they
        are unused, and the wrapper passes an empty placeholder tensor for each.
    """
    num_seqs, H_Q, D = q.shape

    assert D == head_dim
    assert H_Q % num_kv_heads == 0
    assert kv_cache.dtype == torch.uint8
    assert block_table.dtype == torch.int32
    assert seq_lens.dtype == torch.int32

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)

    half_D = head_dim // 2

    # Byte layout constexprs
    k_norm_offset = num_kv_heads * half_D
    v_idx_offset = k_norm_offset + num_kv_heads * 4
    v_norm_offset = v_idx_offset + num_kv_heads * half_D

    # Pre-rotate Q by Pi^T (O(num_seqs), not O(cache_len))
    q_rot = torch.matmul(q.float(), rotation.T).to(q.dtype)

    # Always use a private scratch buffer for the kernel's rotated-space output.
    # When ``out`` is provided, the final (post-rotated) result is copied into it.
    out_rot = torch.empty_like(q)

    # INT8 placeholders (unused, compiled out)
    dummy = torch.empty(0, device=q.device)

    grid = (num_seqs, H_Q)

    _fused_paged_tq4_decode_kernel[grid](
        q_rot,
        kv_cache,
        block_table,
        seq_lens,
        centroids,
        out_rot,
        dummy,  # Q_scale
        dummy,  # QJL_S
        dummy,  # QJL_signs
        dummy,  # QJL_residual_norms
        q_rot.stride(0),
        q_rot.stride(1),
        q_rot.stride(2),
        kv_cache.stride(0),
        kv_cache.stride(1),
        block_table.stride(0),
        block_table.stride(1),
        out_rot.stride(0),
        out_rot.stride(1),
        out_rot.stride(2),
        sm_scale=sm_scale,
        H_Q=H_Q,
        H_KV=num_kv_heads,
        HEAD_DIM=head_dim,
        BLOCK_SIZE=block_size,
        HALF_D=half_D,
        K_NORM_OFFSET=k_norm_offset,
        V_IDX_OFFSET=v_idx_offset,
        V_NORM_OFFSET=v_norm_offset,
        USE_INT8_QK=False,
        QJL_DIM=0,
    )

    # Post-rotate: convert from rotated space back to original space
    result = torch.matmul(out_rot.float(), rotation).to(q.dtype)
    if out is not None:
        out.copy_(result)
        return out
    return result
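
The three byte-layout offsets computed in the wrapper can be reproduced on their own to see how one token slot of the packed cache is divided. The sketch below mirrors that arithmetic. The helper name `tq4_slot_layout` is hypothetical, and the `total_bytes` value assumes the slot ends immediately after the V norms, which the listing above does not state.

```python
def tq4_slot_layout(num_kv_heads: int, head_dim: int) -> dict[str, int]:
    """Byte offsets inside one token slot, mirroring the wrapper's arithmetic."""
    half_d = head_dim // 2  # two 4-bit indices share one byte
    norm_bytes = 4  # the wrapper reserves 4 bytes per head for each norm
    k_norm_offset = num_kv_heads * half_d
    v_idx_offset = k_norm_offset + num_kv_heads * norm_bytes
    v_norm_offset = v_idx_offset + num_kv_heads * half_d
    # Assumption: nothing follows the V norms, so the slot ends here.
    total_bytes = v_norm_offset + num_kv_heads * norm_bytes
    return {
        "k_idx_offset": 0,
        "k_norm_offset": k_norm_offset,
        "v_idx_offset": v_idx_offset,
        "v_norm_offset": v_norm_offset,
        "total_bytes": total_bytes,
    }


layout = tq4_slot_layout(num_kv_heads=8, head_dim=128)

# Size of the same K+V slot stored as uncompressed fp16, for comparison.
fp16_bytes = 2 * 8 * 128 * 2
compression_ratio = fp16_bytes / layout["total_bytes"]
```

For 8 KV heads and a head dimension of 128, the offsets come out to 512, 544, and 1056, and under the stated assumption each slot occupies 1088 bytes against 4096 bytes for fp16.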

fused_qk_scores

fused_qk_scores(
    q_rotated: Tensor,
    packed_indices: Tensor,
    norms: Tensor,
    centroids: Tensor,
    scale: float,
    *,
    n_q_heads: int,
    n_kv_heads: int,
) -> Tensor

Compute attention scores from pre-rotated queries and nibble-packed keys.

The query must be pre-rotated by the TurboQuant rotation matrix (q_rot = q @ Pi_T) so the kernel avoids the expensive 128x128 rotation matmul in its inner loop.

Parameters:

Name Type Description Default
q_rotated Tensor

Pre-rotated query, shape (batch, n_q_heads, q_len, head_dim).

required
packed_indices Tensor

Nibble-packed 4-bit key indices, shape (batch, n_kv_heads, kv_len, head_dim // 2), dtype uint8.

required
norms Tensor

Key vector norms, shape (batch, n_kv_heads, kv_len), dtype float32.

required
centroids Tensor

Lloyd-Max centroid values, shape (n_levels,), dtype float32.

required
scale float

Attention scale factor (typically 1 / sqrt(head_dim)).

required
n_q_heads int

Number of query attention heads.

required
n_kv_heads int

Number of key-value attention heads.

required

Returns:

Type Description
Tensor

Attention scores, shape (batch, n_q_heads, q_len, kv_len).
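
For checking the kernel's output on small inputs, a slow reference for a single query/key pair can be written in plain Python. The sketch below is not the library's implementation. It assumes the even-position index sits in the low nibble of each byte and that a key is reconstructed as `norm * centroids[index]`; neither detail is confirmed by this page, so both should be verified against the compress kernel before relying on it. All function names are hypothetical.

```python
def pack_nibbles(indices):
    """Pack pairs of 4-bit indices into bytes (assumed: even position -> low nibble)."""
    return [
        (indices[i] & 0xF) | ((indices[i + 1] & 0xF) << 4)
        for i in range(0, len(indices), 2)
    ]


def unpack_nibbles(packed):
    """Inverse of pack_nibbles: one byte yields two 4-bit indices."""
    indices = []
    for byte in packed:
        indices.append(byte & 0xF)
        indices.append(byte >> 4)
    return indices


def reference_score(q_rot, packed_key, norm, centroids, scale):
    """Score for one pre-rotated query against one packed key.

    Assumed reconstruction: key[d] = norm * centroids[index[d]].
    """
    indices = unpack_nibbles(packed_key)
    return scale * norm * sum(qd * centroids[i] for qd, i in zip(q_rot, indices))


# Toy 16-level codebook and a head_dim of 4 (two packed bytes per key).
centroids = [(i - 7.5) / 4 for i in range(16)]
q_rot = [1.0, -2.0, 0.5, 0.0]
key_indices = [0, 15, 7, 8]

packed_key = pack_nibbles(key_indices)
score = reference_score(q_rot, packed_key, norm=2.0, centroids=centroids, scale=0.5)
```

Comparing this against `fused_qk_scores` on a handful of positions is a quick way to catch a nibble-order or head-mapping mismatch.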

Source code in src/turboquant_vllm/triton/fused_qk_attention.py
def fused_qk_scores(
    q_rotated: torch.Tensor,
    packed_indices: torch.Tensor,
    norms: torch.Tensor,
    centroids: torch.Tensor,
    scale: float,
    *,
    n_q_heads: int,
    n_kv_heads: int,
) -> torch.Tensor:
    """Compute attention scores from pre-rotated queries and nibble-packed keys.

    The query must be pre-rotated by the TurboQuant rotation matrix
    (``q_rot = q @ Pi_T``) so the kernel avoids the expensive 128x128
    rotation matmul in its inner loop.

    Args:
        q_rotated: Pre-rotated query, shape
            ``(batch, n_q_heads, q_len, head_dim)``.
        packed_indices: Nibble-packed 4-bit key indices, shape
            ``(batch, n_kv_heads, kv_len, head_dim // 2)``, dtype uint8.
        norms: Key vector norms, shape
            ``(batch, n_kv_heads, kv_len)``, dtype float32.
        centroids: Lloyd-Max centroid values, shape ``(n_levels,)``,
            dtype float32.
        scale: Attention scale factor (typically ``1 / sqrt(head_dim)``).
        n_q_heads: Number of query attention heads.
        n_kv_heads: Number of key-value attention heads.

    Returns:
        Attention scores, shape ``(batch, n_q_heads, q_len, kv_len)``.
    """
    batch, _, q_len, head_dim = q_rotated.shape
    _, _, kv_len, _ = packed_indices.shape

    ki_flat = packed_indices.reshape(
        batch * n_kv_heads, kv_len, head_dim // 2
    ).contiguous()
    kn_flat = norms.reshape(batch * n_kv_heads, kv_len).contiguous()
    centroids = centroids.contiguous().float()

    # For q_len > 1 (prefill), process each query position separately
    # to keep the GQA head mapping correct. The kernel maps
    # q_head → kv_head via gqa_ratio = n_q_heads // n_kv_heads.
    # Flattening q_len into heads would break this mapping.
    results = []
    for q_pos in range(q_len):
        q_slice = (
            q_rotated[:, :, q_pos : q_pos + 1, :]
            .reshape(batch * n_q_heads, head_dim)
            .contiguous()
        )

        out = torch.empty(
            batch * n_q_heads,
            kv_len,
            device=q_rotated.device,
            dtype=torch.float32,
        )

        grid = (batch * n_q_heads, triton.cdiv(kv_len, 64))

        _fused_qk_nibble_kernel[grid](
            q_slice,
            ki_flat,
            kn_flat,
            centroids,
            out,
            kv_len,
            head_dim,
            n_q_heads,
            n_kv_heads,
            scale,
            q_slice.stride(0),
            q_slice.stride(1),
            ki_flat.stride(0),
            ki_flat.stride(1),
            ki_flat.stride(2),
            kn_flat.stride(0),
            kn_flat.stride(1),
            out.stride(0),
            out.stride(1),
        )

        results.append(out.reshape(batch, n_q_heads, 1, kv_len))

    return torch.cat(results, dim=2)
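
The comment in the loop above explains why each query position is launched separately: the kernel derives the KV head from the flat query row, so the row axis must contain only `batch * n_q_heads` entries. The sketch below illustrates that in plain Python. The exact index arithmetic inside `_fused_qk_nibble_kernel` is not shown on this page, so the mapping used here (`kv_head = q_head // gqa_ratio`, with heads grouped contiguously) is an assumption, and `kv_row_for_q_row` is a hypothetical helper.

```python
def kv_row_for_q_row(q_row: int, n_q_heads: int, n_kv_heads: int) -> int:
    """Map a flat (batch * n_q_heads) row to its flat (batch * n_kv_heads) row."""
    gqa_ratio = n_q_heads // n_kv_heads
    batch_idx, q_head = divmod(q_row, n_q_heads)
    return batch_idx * n_kv_heads + q_head // gqa_ratio


n_q_heads, n_kv_heads, batch = 8, 2, 2

# With one row per (batch, q_head), consecutive groups of 4 query heads
# share one KV row, and the second batch item starts at KV row 2.
mapping = [
    kv_row_for_q_row(row, n_q_heads, n_kv_heads)
    for row in range(batch * n_q_heads)
]

# If q_len were folded into the row axis, rows would be indexed by
# (batch, q_head, position) and the same arithmetic would pick the wrong KV head.
q_len = 2
b, h, pos = 0, 2, 0
flat_row = (b * n_q_heads + h) * q_len + pos
wrong_kv_row = kv_row_for_q_row(flat_row, n_q_heads, n_kv_heads)
right_kv_row = b * n_kv_heads + h // (n_q_heads // n_kv_heads)
```

Here query head 2 at position 0 lands on flat row 4 once `q_len` is folded in, which the mapping resolves to KV row 1 instead of the correct KV row 0.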