tq4_compress

turboquant_vllm.triton.tq4_compress

Fused Triton kernel for TQ4 compression (norm + rotate + quantize + pack).

Phase 3c.9: Replaces the multi-op PyTorch compress path with a single fused kernel. The rotation matrix is pre-split into even/odd column halves so the kernel writes packed nibble output directly without a separate interleave step.
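The even/odd pre-split described above can be done with plain column slicing at setup time. A minimal sketch (the names `rotation` and `D` are hypothetical stand-ins for wherever the backend prepares its weights):

```python
import torch

# Hypothetical setup: any orthonormal rotation over head dimension D.
D = 128
rotation = torch.linalg.qr(torch.randn(D, D))[0]

# Transpose once, then split into even/odd column halves. The kernel can
# then load each half contiguously and write one packed byte per
# (even, odd) output pair without a separate interleave pass.
rotation_T = rotation.T.contiguous()
rotation_T_even = rotation_T[:, 0::2].contiguous()  # (D, D//2)
rotation_T_odd = rotation_T[:, 1::2].contiguous()   # (D, D//2)
```

Interleaving the two halves back column-by-column recovers `rotation.T` exactly, which is a cheap sanity check when wiring this up.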

Experiment 015 (post-3c.8) showed compress accounts for 53% of decode time (~0.149ms for K+V at 1 token). The PyTorch path launches 6+ CUDA kernels (norm, divide, matmul, bucketize, clamp, pack). Fusing into one Triton launch eliminates kernel-launch overhead.
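For reference, the unfused path being replaced corresponds roughly to a chain like the following. This is a sketch, not the project's actual implementation; the nibble order (even code in the low nibble) and the `clamp_min` guard against zero vectors are assumptions made for illustration:

```python
import torch

def tq4_compress_reference(x, rotation, boundaries):
    """Unfused sketch of the compress pipeline. Each step below is
    (at least) one CUDA kernel launch in eager PyTorch."""
    norms = x.float().norm(dim=-1, keepdim=True)         # 1. norm
    unit = x.float() / norms.clamp_min(1e-12)            # 2. divide
    rotated = unit @ rotation.T                          # 3. matmul (rotate)
    codes = torch.bucketize(rotated, boundaries)         # 4. bucketize
    codes = codes.clamp(0, 15).to(torch.uint8)           # 5. clamp to 4-bit range
    packed = codes[..., 0::2] | (codes[..., 1::2] << 4)  # 6. nibble-pack
    return packed, norms
```

The fused kernel performs all six steps in one launch, producing the same `(N, H, D//2)` uint8 / `(N, H, 1)` fp32 output pair.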

Attributes:

Name Type Description
tq4_compress callable

Python wrapper that launches the fused kernel; returns tuple[Tensor, Tensor].

Examples:

from turboquant_vllm.triton.tq4_compress import tq4_compress

packed, norms = tq4_compress(
    x,
    rotation_T_even,
    rotation_T_odd,
    boundaries,
)
# packed: (N, H, D//2) uint8, norms: (N, H, 1) fp32
See Also

turboquant_vllm.triton.tq4_decompress: Phase 3c.8 fused decompress.
turboquant_vllm.vllm.tq4_backend: vLLM backend that calls this kernel.

Functions

tq4_compress

tq4_compress(
    x: Tensor, rotation_T_even: Tensor, rotation_T_odd: Tensor, boundaries: Tensor
) -> tuple[Tensor, Tensor]

Compress vectors to TQ4 nibble-packed format.

Fused Triton path: norm + normalize + tiled rotation + bucketize + nibble-pack in a single kernel launch.

Parameters:

x (Tensor, required): (N, H, D) fp16/bf16 input vectors.
rotation_T_even (Tensor, required): (D, D//2) fp32 -- even columns of rotation.T, pre-split for contiguous loads.
rotation_T_odd (Tensor, required): (D, D//2) fp32 -- odd columns of rotation.T.
boundaries (Tensor, required): (N_BOUND,) fp32 sorted quantization boundaries.
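The boundaries themselves are produced elsewhere in the project and are tuned to the rotated-unit-vector distribution. As an illustration only (uniform levels are a hypothetical choice, not the project's): a 4-bit quantizer with 16 reconstruction levels needs 15 sorted boundaries so that torch.bucketize maps values to codes 0..15:

```python
import torch

# Hypothetical: 16 uniform levels on [-1, 1]. The midpoints between
# adjacent levels form the (15,) sorted boundary tensor (N_BOUND = 15).
levels = torch.linspace(-1.0, 1.0, 16)
boundaries = (levels[:-1] + levels[1:]) / 2

# bucketize assigns each value the index of its half-open bucket,
# yielding codes in the 4-bit range 0..15.
codes = torch.bucketize(torch.tensor([-1.0, 0.3, 1.0]), boundaries)
```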

Returns:

tuple[Tensor, Tensor]: (packed, norms) where packed is (N, H, D//2) uint8 and norms is (N, H, 1) fp32.

Source code in src/turboquant_vllm/triton/tq4_compress.py
def tq4_compress(
    x: torch.Tensor,
    rotation_T_even: torch.Tensor,
    rotation_T_odd: torch.Tensor,
    boundaries: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compress vectors to TQ4 nibble-packed format.

    Fused Triton path: norm + normalize + tiled rotation + bucketize +
    nibble-pack in a single kernel launch.

    Args:
        x: ``(N, H, D)`` fp16/bf16 input vectors.
        rotation_T_even: ``(D, D//2)`` fp32 -- even columns of
            ``rotation.T``, pre-split for contiguous loads.
        rotation_T_odd: ``(D, D//2)`` fp32 -- odd columns.
        boundaries: ``(N_BOUND,)`` fp32 quantization boundaries.

    Returns:
        Tuple of ``(packed, norms)`` where packed is ``(N, H, D//2)``
        uint8 and norms is ``(N, H, 1)`` fp32.
    """
    N, H, D = x.shape
    HALF_D = D // 2
    M = N * H

    if not x.is_cuda:
        return _tq4_compress_cpu(x, rotation_T_even, rotation_T_odd, boundaries)

    x_flat = x.reshape(M, D).contiguous()
    packed = torch.empty(M, HALF_D, dtype=torch.uint8, device=x.device)
    norms = torch.empty(M, dtype=torch.float32, device=x.device)

    N_BOUND = boundaries.shape[0]
    BLOCK_K = min(32, D)

    grid = (M,)
    _tq4_compress_kernel[grid](
        x_flat,
        rotation_T_even,
        rotation_T_odd,
        boundaries,
        packed,
        norms,
        M,
        D=D,  # ty: ignore[invalid-argument-type]
        HALF_D=HALF_D,  # ty: ignore[invalid-argument-type]
        N_BOUND=N_BOUND,  # ty: ignore[invalid-argument-type]
        BLOCK_K=BLOCK_K,  # ty: ignore[invalid-argument-type]
    )

    return packed.reshape(N, H, HALF_D), norms.reshape(N, H, 1)