tq4_compress

turboquant_vllm.triton.tq4_compress

Fused Triton kernel for TQ4 compression (norm + rotate + quantize + pack).

Phase 3c.9: Replaces the multi-op PyTorch compress path with a single fused kernel. The rotation matrix is pre-split into even/odd column halves so the kernel writes packed nibble output directly without a separate interleave step. Non-power-of-two HEAD_DIM (e.g., 96) is supported via padded tl.arange + masking.
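The pre-split can be sketched on the CPU as follows. This is a minimal illustration, not the package's code: NumPy stands in for PyTorch, the random orthogonal `rotation` is a placeholder for the real rotation matrix, and it assumes the kernel computes `x @ rotation.T`. It shows that multiplying by the two halves yields the even- and odd-indexed rotated coordinates separately, so each output byte can be assembled without an interleave step.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 8  # small even head dimension, for illustration only

# Placeholder rotation: a random orthogonal matrix (assumption, not the real one).
rotation, _ = np.linalg.qr(rng.standard_normal((D, D)))
rotation = rotation.astype(np.float32)

# Pre-split rotation.T into even and odd columns, each of shape (D, D // 2).
rotation_T = rotation.T
rotation_T_even = np.ascontiguousarray(rotation_T[:, 0::2])
rotation_T_odd = np.ascontiguousarray(rotation_T[:, 1::2])

x = rng.standard_normal((3, D)).astype(np.float32)
full = x @ rotation_T        # all rotated coordinates, shape (3, D)
even = x @ rotation_T_even   # matches full[:, 0::2]
odd = x @ rotation_T_odd     # matches full[:, 1::2]

print(np.allclose(even, full[:, 0::2], atol=1e-5),
      np.allclose(odd, full[:, 1::2], atol=1e-5))
```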

Experiment 015 (post-3c.8) showed that compress accounts for 53% of decode time (~0.149 ms for K+V at 1 token). The PyTorch path launches 6+ CUDA kernels (norm, divide, matmul, bucketize, clamp, pack). Fusing them into one Triton launch leaves a single launch per call and removes the launch overhead of the rest.
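For orientation, the multi-op sequence listed above (norm, divide, matmul, bucketize, clamp, pack) could look roughly like this. It is a NumPy reference sketch under stated assumptions, not the PyTorch code the kernel replaced: the name `tq4_compress_reference`, the zero-norm guard, the placeholder boundaries, and the nibble order (even index in the low nibble) are all assumptions made for illustration.

```python
import numpy as np

def tq4_compress_reference(x, rotation, boundaries):
    """Unfused reference sketch: norm, divide, matmul, bucketize, clamp, pack."""
    x = x.astype(np.float32)
    norms = np.linalg.norm(x, axis=-1, keepdims=True)   # norm, shape (..., 1)
    unit = x / np.maximum(norms, 1e-12)                 # divide (guard is an assumption)
    rotated = unit @ rotation.T                         # matmul
    idx = np.searchsorted(boundaries, rotated)          # bucketize against sorted boundaries
    idx = np.clip(idx, 0, 15).astype(np.uint8)          # clamp to a 4-bit code
    # Pack two codes per byte; which code goes in the low nibble is assumed here.
    packed = idx[..., 0::2] | (idx[..., 1::2] << 4)
    return packed, norms

x = np.random.default_rng(0).standard_normal((2, 3, 8)).astype(np.float16)
boundaries = np.linspace(-1.0, 1.0, 17, dtype=np.float32)[1:-1]  # 15 placeholder boundaries
packed, norms = tq4_compress_reference(x, np.eye(8, dtype=np.float32), boundaries)
print(packed.shape, packed.dtype, norms.shape)
```

Each step in this sketch maps to at least one separate kernel launch in an eager GPU framework, which is the overhead the fused kernel targets.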

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| tq4_compress | tuple[Tensor, Tensor] | Python wrapper that launches the fused kernel. |

Examples:

from turboquant_vllm.triton.tq4_compress import tq4_compress

packed, norms = tq4_compress(
    x,
    rotation_T_even,
    rotation_T_odd,
    boundaries,
)
# packed: (N, H, D//2) uint8, norms: (N, H, 1) fp32

See Also

turboquant_vllm.triton.tq4_decompress: Phase 3c.8 fused decompress.
turboquant_vllm.vllm.tq4_backend: vLLM backend that calls this kernel.

Functions

tq4_compress

tq4_compress(
    x: Tensor,
    rotation_T_even: Tensor,
    rotation_T_odd: Tensor,
    boundaries: Tensor,
    out: tuple[Tensor, Tensor] | None = None,
) -> tuple[Tensor, Tensor]

Compress vectors to TQ4 nibble-packed format.

Fused Triton path: norm + normalize + tiled rotation + bucketize + nibble-pack in a single kernel launch. Non-power-of-two head dimensions (e.g., 96) are supported via padded tile loads and boundary masking inside the kernel. Non-pow2 dims incur ~5-15 % throughput penalty due to wasted lanes in padded tiles.
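The padded-load idea can be emulated in NumPy. In Triton, `tl.arange` requires a power-of-two length, so a kernel typically builds offsets over the padded width and masks the lanes at or beyond the real dimension. The sketch below mimics that for the norm step only; `masked_row_norm` is a hypothetical helper, and the real kernel's tiling is not shown on this page.

```python
import numpy as np

def masked_row_norm(row, d_pad):
    """Emulate a padded, masked load followed by an L2-norm reduction."""
    d = row.shape[0]
    assert d <= d_pad, "padded width must cover the real dimension"
    offs = np.arange(d_pad)                      # analogue of tl.arange(0, D_PAD)
    mask = offs < d                              # lanes >= D are padding
    padded = np.zeros(d_pad, dtype=np.float32)   # masked lanes read as 0.0
    padded[mask] = row                           # analogue of tl.load(ptr + offs, mask=mask, other=0.0)
    return float(np.sqrt(np.sum(padded * padded)))

row = np.arange(96, dtype=np.float32)            # D = 96 is not a power of two
print(masked_row_norm(row, d_pad=128), float(np.linalg.norm(row)))
```

Because the masked lanes contribute zeros, the reduction matches the unpadded result; the cost is that 32 of the 128 lanes do no useful work for D = 96.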

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | (N, H, D) fp16/bf16 input vectors. | required |
| rotation_T_even | Tensor | (D, D//2) fp32 -- even columns of rotation.T, pre-split for contiguous loads. | required |
| rotation_T_odd | Tensor | (D, D//2) fp32 -- odd columns. | required |
| boundaries | Tensor | (N_BOUND,) fp32 quantization boundaries. | required |
| out | tuple[Tensor, Tensor] \| None | Optional pre-allocated (packed, norms) buffers. When provided, results are written into these tensors and the same objects are returned. Follows PyTorch out convention. | None |

Returns:

| Type | Description |
| --- | --- |
| tuple[Tensor, Tensor] | Tuple of (packed, norms) where packed is (N, H, D//2) uint8 and norms is (N, H, 1) fp32. |
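The output shapes translate into a simple storage estimate. The helpers below are hypothetical (not part of the package) and only do the arithmetic implied by the documented shapes: one uint8 per two 4-bit codes plus one fp32 norm per vector, compared with a 2-byte fp16/bf16 input element.

```python
def tq4_out_shapes(n, h, d):
    """Shapes of (packed, norms) for an (N, H, D) input."""
    assert d % 2 == 0, "HEAD_DIM must be even"
    return (n, h, d // 2), (n, h, 1)

def bytes_per_vector(d, input_bytes=2):
    """Return (uncompressed, compressed) bytes for one D-dimensional vector."""
    packed_bytes = d // 2   # one uint8 holds two 4-bit codes
    norm_bytes = 4          # one fp32 norm per vector
    return input_bytes * d, packed_bytes + norm_bytes

raw, compressed = bytes_per_vector(128)
print(tq4_out_shapes(4, 8, 128), raw, compressed, round(raw / compressed, 2))
```

The same shapes are what a caller would allocate for pre-allocated `out` buffers.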

Source code in src/turboquant_vllm/triton/tq4_compress.py
def tq4_compress(
    x: torch.Tensor,
    rotation_T_even: torch.Tensor,
    rotation_T_odd: torch.Tensor,
    boundaries: torch.Tensor,
    out: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compress vectors to TQ4 nibble-packed format.

    Fused Triton path: norm + normalize + tiled rotation + bucketize +
    nibble-pack in a single kernel launch. Non-power-of-two head
    dimensions (e.g., 96) are supported via padded tile loads and
    boundary masking inside the kernel.  Non-pow2 dims incur ~5-15 %
    throughput penalty due to wasted lanes in padded tiles.

    Args:
        x: ``(N, H, D)`` fp16/bf16 input vectors.
        rotation_T_even: ``(D, D//2)`` fp32 -- even columns of
            ``rotation.T``, pre-split for contiguous loads.
        rotation_T_odd: ``(D, D//2)`` fp32 -- odd columns.
        boundaries: ``(N_BOUND,)`` fp32 quantization boundaries.
        out: Optional pre-allocated ``(packed, norms)`` buffers.  When
            provided, results are written into these tensors and the same
            objects are returned.  Follows PyTorch ``out`` convention.

    Returns:
        Tuple of ``(packed, norms)`` where packed is ``(N, H, D//2)``
        uint8 and norms is ``(N, H, 1)`` fp32.
    """
    N, H, D = x.shape
    assert D % 2 == 0, f"HEAD_DIM must be even, got {D}"
    HALF_D = D // 2
    M = N * H

    if not x.is_cuda:
        return _tq4_compress_cpu(x, rotation_T_even, rotation_T_odd, boundaries, out)

    x_flat = x.reshape(M, D).contiguous()
    if out is not None:
        packed, norms = out
    else:
        packed = torch.empty(M, HALF_D, dtype=torch.uint8, device=x.device)
        norms = torch.empty(M, dtype=torch.float32, device=x.device)

    N_BOUND = boundaries.shape[0]
    BLOCK_K = min(32, D)
    D_PAD = _next_pow2(D)
    HALF_D_PAD = _next_pow2(HALF_D)
    assert HALF_D_PAD * 2 == D_PAD, (
        f"Padding invariant violated: 2*HALF_D_PAD ({2 * HALF_D_PAD}) "
        f"!= D_PAD ({D_PAD}) — nibble pack/unpack requires this"
    )

    grid = (M,)
    _tq4_compress_kernel[grid](
        x_flat,
        rotation_T_even,
        rotation_T_odd,
        boundaries,
        packed,
        norms,
        M,
        D=D,  # ty: ignore[invalid-argument-type]
        HALF_D=HALF_D,  # ty: ignore[invalid-argument-type]
        N_BOUND=N_BOUND,  # ty: ignore[invalid-argument-type]
        BLOCK_K=BLOCK_K,  # ty: ignore[invalid-argument-type]
        D_PAD=D_PAD,  # ty: ignore[invalid-argument-type]
        HALF_D_PAD=HALF_D_PAD,  # ty: ignore[invalid-argument-type]
    )

    if out is not None:
        return packed, norms
    return packed.reshape(N, H, HALF_D), norms.reshape(N, H, 1)
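
The padding invariant asserted in the wrapper can be checked in isolation. `_next_pow2` is not shown on this page, so `next_pow2` below is an assumed equivalent (the smallest power of two greater than or equal to its argument), and `padding_waste` is a hypothetical helper for the fraction of masked lanes.

```python
def next_pow2(n: int) -> int:
    """Smallest power of two >= n (assumed stand-in for _next_pow2)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def padding_waste(d: int) -> float:
    """Fraction of lanes in a padded tile that are masked off."""
    d_pad = next_pow2(d)
    return (d_pad - d) / d_pad

for d in (64, 96, 128):
    half_d = d // 2
    # Same check as the wrapper: padding D and D // 2 separately must agree.
    assert 2 * next_pow2(half_d) == next_pow2(d)
    print(d, next_pow2(d), next_pow2(half_d), f"{padding_waste(d):.0%}")
```

For D = 96 a quarter of the lanes are padding. That lane fraction is an upper-level indicator only; the measured throughput penalty quoted in the docstring is ~5-15 %.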