tq4_decompress

turboquant_vllm.triton.tq4_decompress

Fused Triton kernel for TQ4 cache decompression (no rotation).

Phase 3c.8: Replaces the multi-op PyTorch decompress path with a single fused kernel that performs nibble unpack -> centroid gather -> norm scale -> dtype cast in one launch. The rotation is not applied here -- the caller pre-rotates Q by Pi^T and post-rotates the attention output by Pi, saving O(cache_len) matmuls per decode step.

Experiment 015 showed that at cache_len=4096 on RTX 4090, decompress accounts for 68% of decode time. The rotation matmul (128x128) is the dominant cost within that 68%. By moving rotation to Q/output (O(1) per decode step), the kernel only needs elementwise + gather ops.
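The rotation-hoisting trick above relies on Pi being orthogonal: scores are invariant when Q is pre-rotated by Pi^T against rotated keys, and the output can be computed in rotated space and post-rotated by Pi afterwards. A minimal NumPy sketch of this identity (an illustration only; the storage convention X_rot = X @ Pi.T is an assumption, not code from the repo):

```python
import numpy as np

rng = np.random.default_rng(0)
D = 8
# Random orthogonal Pi via QR (stands in for the TQ4 rotation matrix)
Pi, _ = np.linalg.qr(rng.standard_normal((D, D)))

Q = rng.standard_normal((4, D))
K = rng.standard_normal((16, D))
V = rng.standard_normal((16, D))

# Cache stores rotated K/V (assumed convention: X_rot = X @ Pi.T)
K_rot, V_rot = K @ Pi.T, V @ Pi.T

# Scores are invariant when Q is pre-rotated by Pi.T
scores = Q @ K.T
assert np.allclose((Q @ Pi.T) @ K_rot.T, scores)

# Softmax weights, then output: compute in rotated space, post-rotate by Pi
A = np.exp(scores - scores.max(-1, keepdims=True))
A /= A.sum(-1, keepdims=True)
assert np.allclose((A @ V_rot) @ Pi, A @ V)
```

Because the identity holds exactly, decompression can skip the 128x128 rotation matmul entirely, leaving only elementwise and gather work per cached vector.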

Attributes:

Name Type Description
tq4_decompress callable

Python wrapper that launches the fused kernel.

Examples:

import torch
from turboquant_vllm.triton.tq4_decompress import tq4_decompress

# packed: (N, H, D//2) uint8, norms: (N, H, 1) fp32
out = tq4_decompress(packed, norms, centroids, dtype=torch.float16)
# out: (N, H, D) fp16 -- still in rotated space (no Pi applied)
See Also

turboquant_vllm.triton.flash_attention_tq4: Phase 2 fused FA+K kernel.
turboquant_vllm.vllm.tq4_backend: vLLM backend that calls this kernel.

Functions

tq4_decompress

tq4_decompress(
    packed: Tensor, norms: Tensor, centroids: Tensor, dtype: dtype = float16
) -> Tensor

Decompress TQ4 nibble-packed data to full-precision vectors.

Fused Triton path: unpack + centroid gather + norm scale + cast in a single kernel launch. Does not apply the rotation matrix -- output remains in rotated space.

Parameters:

Name Type Description Default
packed Tensor (N, H, D//2) uint8 -- nibble-packed centroid indices. required
norms Tensor (N, H, 1) fp32 -- per-vector norms. required
centroids Tensor (C,) fp32 -- centroid table (C=16 for TQ4). required
dtype dtype Output dtype. torch.float16

Returns:

Type Description
Tensor Tensor of shape (N, H, D) in dtype, still in rotated space (caller must apply post-rotation if needed).
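To make the packed layout concrete, here is a sketch of the inverse (packing) side: two 4-bit centroid indices per uint8 byte. Low-nibble-first ordering is an assumption for illustration, not confirmed by the source:

```python
import torch

# Hypothetical packing sketch: pair adjacent 4-bit indices into one byte.
idx = torch.tensor([[[1, 2, 15, 0]]], dtype=torch.uint8)  # (N, H, D), values in [0, 16)
packed = idx[..., 0::2] | (idx[..., 1::2] << 4)           # (N, H, D//2)
print(packed.tolist())  # [[[33, 15]]]  i.e. 0x21, 0x0F
```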

Source code in src/turboquant_vllm/triton/tq4_decompress.py
def tq4_decompress(
    packed: torch.Tensor,
    norms: torch.Tensor,
    centroids: torch.Tensor,
    dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Decompress TQ4 nibble-packed data to full-precision vectors.

    Fused Triton path: unpack + centroid gather + norm scale + cast in a
    single kernel launch.  Does **not** apply the rotation matrix --
    output remains in rotated space.

    Args:
        packed: ``(N, H, D//2)`` uint8 -- nibble-packed centroid indices.
        norms: ``(N, H, 1)`` fp32 -- per-vector norms.
        centroids: ``(C,)`` fp32 -- centroid table (C=16 for TQ4).
        dtype: Output dtype (default: ``torch.float16``).

    Returns:
        Tensor of shape ``(N, H, D)`` in ``dtype``, still in rotated
        space (caller must apply post-rotation if needed).
    """
    N, H, half_D = packed.shape
    D = half_D * 2
    M = N * H

    # CPU fallback: PyTorch path (Triton requires CUDA)
    if not packed.is_cuda:
        return _tq4_decompress_cpu(packed, norms, centroids, dtype)

    # Flatten to 2D for the kernel
    packed_flat = packed.reshape(M, half_D).contiguous()
    norms_flat = norms.reshape(M).contiguous()

    out = torch.empty(M, D, dtype=dtype, device=packed.device)

    grid = (M,)
    _tq4_decompress_kernel[grid](
        packed_flat,
        norms_flat,
        centroids,
        out,
        M,
        HALF_D=half_D,  # ty: ignore[invalid-argument-type]
    )

    return out.reshape(N, H, D)