
tq4_decompress

turboquant_vllm.triton.tq4_decompress

Fused Triton kernel for TQ4 cache decompression (no rotation).

Phase 3c.8: Replaces the multi-op PyTorch decompress path with a single fused kernel that performs nibble unpack -> centroid gather -> norm scale -> dtype cast in one launch. The rotation is not applied here -- the caller pre-rotates Q by Pi^T and post-rotates the attention output by Pi, saving O(cache_len) matmuls per decode step. Non-power-of-two HEAD_DIM (e.g., 96) is supported via padded tl.arange + masking.
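The four fused steps can be spelled out with a minimal NumPy reference sketch. This is not the library's CPU fallback (`_tq4_decompress_cpu` is not shown on this page), and `tq4_decompress_reference` is a hypothetical name. It assumes that each byte stores two 4-bit indices, with the low nibble holding the even coordinate and the high nibble holding the odd one. It also assumes that reconstruction is simply `centroids[idx] * norm`, which is what the "norm scale" step suggests.

```python
import numpy as np


def tq4_decompress_reference(packed, norms, centroids, dtype=np.float16):
    """Hypothetical NumPy reference for TQ4 decompression (no rotation).

    Assumes: low nibble -> even coordinate, high nibble -> odd coordinate,
    and reconstruction = centroids[index] * per-vector norm.
    """
    n, h, half_d = packed.shape
    # Step 1: nibble unpack -- split each uint8 into two 4-bit indices.
    lo = packed & 0x0F
    hi = packed >> 4
    idx = np.empty((n, h, half_d * 2), dtype=np.uint8)
    idx[..., 0::2] = lo  # assumed layout: even coordinates in the low nibble
    idx[..., 1::2] = hi  # odd coordinates in the high nibble
    # Step 2: centroid gather -- look up the fp32 value for each index.
    vals = centroids[idx]  # (N, H, D) float32
    # Step 3: norm scale -- broadcast the (N, H, 1) norms across D.
    vals = vals * norms
    # Step 4: dtype cast; the result is still in rotated space.
    return vals.astype(dtype)
```

Under the assumed layout, a byte `0x2F` unpacks to indices `[15, 2]`, so it becomes `[centroids[15] * norm, centroids[2] * norm]`.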

Experiment 015 showed that at cache_len=4096 on an RTX 4090, decompression accounts for 68% of decode time, and the 128x128 rotation matmul is the dominant cost within that 68%. Moving the rotation onto Q and the attention output reduces it from O(cache_len) to O(1) matmuls per decode step, so the kernel only needs elementwise and gather ops.
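The rotation can leave the kernel because it is orthogonal. The NumPy sketch below checks this numerically under an assumed convention:

- Vectors are row vectors.
- Cached keys and values are stored as `x @ Pi.T`.
- `Pi` is a random orthogonal matrix standing in for the real rotation.

Under that convention, scoring a query pre-rotated by `Pi^T` against the rotated keys gives the same logits. Post-rotating the softmax-weighted sum of rotated values by `Pi` gives the same output. Quantization is omitted, so the two paths agree up to floating-point error.

```python
import numpy as np

rng = np.random.default_rng(0)
D, L = 128, 64  # head dim and cache length (illustrative values)

# Random orthogonal matrix standing in for the TurboQuant rotation Pi.
Pi, _ = np.linalg.qr(rng.standard_normal((D, D)))

q = rng.standard_normal(D)
K = rng.standard_normal((L, D))
V = rng.standard_normal((L, D))

# Assumed storage convention: cached rows are rotated as x @ Pi.T.
K_rot = K @ Pi.T
V_rot = V @ Pi.T


def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()


# Reference path: un-rotate every cached row, i.e. an (L, D) x (D, D) matmul
# whose cost grows with cache length.
ref = softmax((K_rot @ Pi) @ q / np.sqrt(D)) @ (V_rot @ Pi)

# Rotated-space path: rotate q once, attend over the rotated rows as stored,
# then rotate the single output vector once.
q_rot = q @ Pi.T
out = (softmax(K_rot @ q_rot / np.sqrt(D)) @ V_rot) @ Pi
```

The reference path pays for a rotation of all `L` cached rows. The rotated-space path does two `D`-vector rotations no matter how long the cache is, which is the O(cache_len) versus O(1) trade described above.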

Attributes:

| Name | Type | Description |
|---|---|---|
| tq4_decompress | Tensor | Python wrapper that launches the fused kernel. |

Examples:

```python
import torch
from turboquant_vllm.triton.tq4_decompress import tq4_decompress

# packed: (N, H, D//2) uint8, norms: (N, H, 1) fp32, centroids: (16,) fp32
out = tq4_decompress(packed, norms, centroids, dtype=torch.float16)
# out: (N, H, D) fp16 -- still in rotated space (no Pi applied)
```

See Also

- `turboquant_vllm.triton.flash_attention_tq4`: Phase 2 fused FA+K kernel.
- `turboquant_vllm.vllm.tq4_backend`: vLLM backend that calls this kernel.

Functions

tq4_decompress

```python
tq4_decompress(
    packed: Tensor,
    norms: Tensor,
    centroids: Tensor,
    dtype: dtype = float16,
    out: Tensor | None = None,
) -> Tensor
```

Decompress TQ4 nibble-packed data to full-precision vectors.

Fused Triton path: unpack, centroid gather, norm scale and cast happen in a single kernel launch. The rotation matrix is not applied, so the output remains in rotated space. Non-power-of-two head dimensions (e.g., 96) are supported via padded tile loads and boundary masking. They cost roughly 5-15% in throughput because of the lanes wasted in the padded tiles.
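The padding arithmetic can be sketched in plain Python. The source calls a `_next_pow2` helper whose body is not shown on this page, so `next_pow2` below is a hypothetical stand-in. The mask mirrors what a padded `tl.arange(0, HALF_D_PAD)` with `offsets < HALF_D` would select. This assumes the kernel tiles over packed bytes (`HALF_D`), as the `HALF_D_PAD=_next_pow2(half_D)` launch argument suggests.

```python
def next_pow2(n: int) -> int:
    """Smallest power of two >= n (hypothetical stand-in for _next_pow2)."""
    if n <= 1:
        return 1
    return 1 << (n - 1).bit_length()


def tile_mask(half_d: int) -> list[bool]:
    """Which lanes of a padded tile map to real packed bytes."""
    half_d_pad = next_pow2(half_d)
    # Plain-Python analogue of: offs = tl.arange(0, HALF_D_PAD); mask = offs < HALF_D
    return [off < half_d for off in range(half_d_pad)]


for head_dim in (64, 96, 128):
    half_d = head_dim // 2
    mask = tile_mask(half_d)
    active = sum(mask)
    print(f"D={head_dim}: HALF_D={half_d}, HALF_D_PAD={len(mask)}, "
          f"lane utilization={active / len(mask):.0%}")
```

The loop reports 100% lane utilization for D=64 and D=128, and 75% for D=96 (48 of 64 lanes active). Those idle, masked-off lanes are what the throughput penalty above refers to.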

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| packed | Tensor | (N, H, D//2) uint8 -- nibble-packed centroid indices. | required |
| norms | Tensor | (N, H, 1) fp32 -- per-vector norms. | required |
| centroids | Tensor | (C,) fp32 -- centroid table (C=16 for TQ4). | required |
| dtype | dtype | Output dtype (default: torch.float16). | float16 |
| out | Tensor \| None | Optional pre-allocated (N, H, D) output tensor. When provided, results are written into it and the same tensor is returned. Follows the PyTorch `out` convention. | None |
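The shape and dtype contract in the table above can be made explicit with a small pre-launch check. Below is a hypothetical NumPy sketch; `check_tq4_inputs` is not part of the library. It encodes what the table states plus one labeled assumption: a caller-supplied `out` must be contiguous. The reasoning is that the wrapper passes `out` straight to a kernel that appears to write a dense row-major `(M, D)` buffer.

```python
import numpy as np


def check_tq4_inputs(packed, norms, centroids, out=None):
    """Hypothetical validator mirroring the documented parameter contract."""
    if packed.ndim != 3 or packed.dtype != np.uint8:
        raise ValueError("packed must be (N, H, D//2) uint8")
    n, h, half_d = packed.shape
    if norms.shape != (n, h, 1) or norms.dtype != np.float32:
        raise ValueError("norms must be (N, H, 1) float32")
    if centroids.shape != (16,) or centroids.dtype != np.float32:
        raise ValueError("centroids must be (16,) float32 for TQ4")
    if out is not None:
        if out.shape != (n, h, 2 * half_d):
            raise ValueError("out must be (N, H, D)")
        # Assumption: the kernel writes `out` as a dense row-major buffer,
        # so a strided view would receive values in the wrong places.
        if not out.flags["C_CONTIGUOUS"]:
            raise ValueError("out must be contiguous")
    return n, h, 2 * half_d
```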

Returns:

| Type | Description |
|---|---|
| Tensor | Tensor of shape (N, H, D) in dtype, still in rotated space (caller must apply post-rotation if needed). |

Source code in src/turboquant_vllm/triton/tq4_decompress.py
```python
def tq4_decompress(
    packed: torch.Tensor,
    norms: torch.Tensor,
    centroids: torch.Tensor,
    dtype: torch.dtype = torch.float16,
    out: torch.Tensor | None = None,
) -> torch.Tensor:
    """Decompress TQ4 nibble-packed data to full-precision vectors.

    Fused Triton path: unpack + centroid gather + norm scale + cast in a
    single kernel launch.  Does **not** apply the rotation matrix --
    output remains in rotated space.  Non-power-of-two head dimensions
    (e.g., 96) are supported via padded tile loads and boundary masking.
    Non-pow2 dims incur ~5-15 % throughput penalty due to wasted lanes
    in padded tiles.

    Args:
        packed: ``(N, H, D//2)`` uint8 -- nibble-packed centroid indices.
        norms: ``(N, H, 1)`` fp32 -- per-vector norms.
        centroids: ``(C,)`` fp32 -- centroid table (C=16 for TQ4).
        dtype: Output dtype (default: ``torch.float16``).
        out: Optional pre-allocated ``(N, H, D)`` output tensor.  When
            provided, results are written into it and the same tensor is
            returned.  Follows PyTorch ``out`` convention.

    Returns:
        Tensor of shape ``(N, H, D)`` in ``dtype``, still in rotated
        space (caller must apply post-rotation if needed).
    """
    N, H, half_D = packed.shape
    D = half_D * 2
    M = N * H

    # CPU fallback: PyTorch path (Triton requires CUDA)
    if not packed.is_cuda:
        return _tq4_decompress_cpu(packed, norms, centroids, dtype, out)

    # Flatten to 2D for the kernel
    packed_flat = packed.reshape(M, half_D).contiguous()
    norms_flat = norms.reshape(M).contiguous()

    caller_out = out
    if out is None:
        out = torch.empty(M, D, dtype=dtype, device=packed.device)

    grid = (M,)
    _tq4_decompress_kernel[grid](
        packed_flat,
        norms_flat,
        centroids,
        out,
        M,
        HALF_D=half_D,  # ty: ignore[invalid-argument-type]
        HALF_D_PAD=_next_pow2(half_D),  # ty: ignore[invalid-argument-type]
    )

    if caller_out is not None:
        return caller_out
    return out.reshape(N, H, D)
```
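The wrapper flattens `(N, H)` into `M = N * H` rows and launches one program per row (`grid = (M,)`). The NumPy sketch below uses arrays as stand-ins for CUDA tensors to show why returning `out.reshape(N, H, D)`, or the caller's own `out`, is sound. It assumes the kernel addresses `out` as a dense row-major `(M, D)` buffer, which is what the wrapper's own `torch.empty(M, D, ...)` allocation implies. For a contiguous buffer, the `(M, D)` and `(N, H, D)` shapes are views of the same memory, and row `m` corresponds to `(n, h) = divmod(m, H)`. A strided `(N, H, D)` tensor lacks that layout.

```python
import numpy as np

N, H, D = 2, 3, 8
M = N * H

# Contiguous (N, H, D) buffer, as a caller might pre-allocate for `out`.
out = np.zeros((N, H, D), dtype=np.float16)
flat = out.reshape(M, D)  # a view: same memory, no copy

# Emulate grid = (M,): program m fills one full row of D values.
for m in range(M):
    flat[m, :] = m  # stand-in for the decompressed row

# Row m of the flat view lands at (n, h) = divmod(m, H) in the 3-D tensor.
for m in range(M):
    n, h = divmod(m, H)
    assert (out[n, h] == m).all()

# A transposed (strided) tensor with the same shape is not row-major:
# flattening it to (M, D) needs a copy, so it is not the same memory.
strided = np.zeros((H, N, D), dtype=np.float16).transpose(1, 0, 2)
print(strided.flags["C_CONTIGUOUS"], np.shares_memory(strided, strided.reshape(M, D)))
```

The final line prints `False False`. Whether a non-contiguous caller `out` causes problems in practice depends on how `_tq4_decompress_kernel` indexes its output, which this page does not show.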