tq4_decompress
turboquant_vllm.triton.tq4_decompress
Fused Triton kernel for TQ4 cache decompression (no rotation).
Phase 3c.8: Replaces the multi-op PyTorch decompress path with a single
fused kernel that performs nibble unpack -> centroid gather -> norm scale
-> dtype cast in one launch. The rotation is not applied here -- the
caller pre-rotates Q by Pi^T and post-rotates the attention output
by Pi, saving O(cache_len) matmuls per decode step. Non-power-of-two
HEAD_DIM (e.g., 96) is supported via padded tl.arange + masking.
Experiment 015 showed that at cache_len=4096 on an RTX 4090, decompress accounts for 68% of decode time, and the rotation matmul (128x128) is the dominant cost within that 68%. Moving the rotation to Q and the attention output makes its cost O(1) per decode step instead of O(cache_len), so the kernel only needs elementwise and gather ops.
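The reason the rotation can be moved out of the decompress path is that an orthogonal Pi cancels inside the attention product. The NumPy sketch below checks this on toy sizes. It assumes a row-vector convention in which the cache holds `K @ Pi.T` and `V @ Pi.T`, uses a random orthogonal matrix as a stand-in for Pi, and ignores quantization error; none of these names come from the library itself.

```python
import numpy as np

rng = np.random.default_rng(0)
D, T = 8, 5  # toy head dim and cache length

# Random orthogonal matrix standing in for Pi (QR of a Gaussian matrix).
Pi, _ = np.linalg.qr(rng.standard_normal((D, D)))

q = rng.standard_normal((1, D))  # one decode-step query
K = rng.standard_normal((T, D))  # cached keys (original space)
V = rng.standard_normal((T, D))  # cached values (original space)


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


# Reference: attention computed entirely in the original space.
ref = softmax(q @ K.T / np.sqrt(D)) @ V

# Rotated path (assumed layout): the cache stays rotated and is never un-rotated.
K_rot, V_rot = K @ Pi.T, V @ Pi.T
q_rot = q @ Pi.T                                        # pre-rotate Q by Pi^T
out_rot = softmax(q_rot @ K_rot.T / np.sqrt(D)) @ V_rot
out = out_rot @ Pi                                      # post-rotate output by Pi

max_err = np.abs(out - ref).max()  # floating-point noise only
```

Only `q` and the final output touch Pi, so the number of rotation matmuls per decode step no longer depends on `T`.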
Attributes:
| Name | Type | Description |
|---|---|---|
| `tq4_decompress` | `Tensor` | Python wrapper that launches the fused kernel. |
Examples:
```python
import torch

from turboquant_vllm.triton.tq4_decompress import tq4_decompress

# packed: (N, H, D//2) uint8, norms: (N, H, 1) fp32
out = tq4_decompress(packed, norms, centroids, dtype=torch.float16)
# out: (N, H, D) fp16 -- still in rotated space (no Pi applied)
```
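For intuition, the four fused steps (nibble unpack, centroid gather, norm scale, dtype cast) can be written as an unfused NumPy reference. This is a minimal sketch, not the kernel: the nibble order (low nibble first), the 16-entry codebook values, and the function name `tq4_decompress_reference` are assumptions made for illustration.

```python
import numpy as np


def tq4_decompress_reference(packed, norms, centroids, dtype=np.float16):
    """Unfused reference: unpack -> gather -> scale -> cast (hypothetical helper)."""
    lo = packed & 0x0F  # assumed: low nibble holds the even-indexed element
    hi = packed >> 4    # assumed: high nibble holds the odd-indexed element
    # Interleave lo/hi so the last axis grows from D//2 to D.
    idx = np.stack([lo, hi], axis=-1).reshape(*packed.shape[:-1], -1)
    vals = centroids[idx]                # centroid gather -> (N, H, D)
    return (vals * norms).astype(dtype)  # norm scale, then dtype cast


N, H, D = 2, 3, 8
rng = np.random.default_rng(0)
packed = rng.integers(0, 256, size=(N, H, D // 2), dtype=np.uint8)
norms = rng.random((N, H, 1), dtype=np.float32) + 0.5
centroids = np.linspace(-1.0, 1.0, 16, dtype=np.float32)  # hypothetical codebook

out = tq4_decompress_reference(packed, norms, centroids)
```

As with the fused kernel, no rotation is applied here, so `out` stays in rotated space.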
See Also:
- `turboquant_vllm.triton.flash_attention_tq4`: Phase 2 fused FA+K kernel.
- `turboquant_vllm.vllm.tq4_backend`: vLLM backend that calls this kernel.
Functions
tq4_decompress
```python
tq4_decompress(
    packed: Tensor,
    norms: Tensor,
    centroids: Tensor,
    dtype: dtype = float16,
    out: Tensor | None = None,
) -> Tensor
```
Decompress TQ4 nibble-packed data to full-precision vectors.
Fused Triton path: unpack + centroid gather + norm scale + cast in a single kernel launch. Does not apply the rotation matrix -- output remains in rotated space. Non-power-of-two head dimensions (e.g., 96) are supported via padded tile loads and boundary masking, at a throughput penalty of roughly 5-15% due to wasted lanes in the padded tiles.
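The padded-tile idea can be emulated on the host. `tl.arange` requires a power-of-two range, so a head dimension of 96 is rounded up to 128 lanes and the extra lanes are masked out on load and store. The NumPy sketch below mimics that pattern; the variable names and the doubling operation are illustrative only and are not taken from the kernel.

```python
import numpy as np


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


HEAD_DIM = 96
BLOCK = next_power_of_2(HEAD_DIM)  # padded tile width

offs = np.arange(BLOCK)  # plays the role of tl.arange(0, BLOCK)
mask = offs < HEAD_DIM   # lanes at or past HEAD_DIM are inactive

row = np.arange(HEAD_DIM, dtype=np.float32)  # one head vector in memory

# Emulate a masked load: inactive lanes read a filler value instead of memory.
safe = np.where(mask, offs, 0)  # clamp so no lane indexes out of bounds
tile = np.where(mask, row[safe], 0.0)

# Emulate a masked store: only active lanes write back.
result = np.empty(HEAD_DIM, dtype=np.float32)
result[offs[mask]] = tile[mask] * 2.0

wasted = 1.0 - HEAD_DIM / BLOCK  # fraction of idle lanes in the padded tile
```

For `HEAD_DIM = 96` a quarter of the lanes are idle. That idle-lane fraction is not the same quantity as the throughput penalty quoted above, which depends on what the kernel is actually bound by.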
Parameters:
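| Name | Type | Description | Default |
|---|---|---|---|
| `packed` | `Tensor` | Nibble-packed TQ4 data; `(N, H, D//2)` uint8 in the example above. | required |
| `norms` | `Tensor` | Per-vector norms used for the scale step; `(N, H, 1)` fp32 in the example above. | required |
| `centroids` | `Tensor` | Centroid values used by the gather step. | required |
| `dtype` | `dtype` | Output dtype (default: `float16`). | `float16` |
| `out` | `Tensor \| None` | Optional pre-allocated output tensor. | `None` |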
Returns:
| Type | Description |
|---|---|
| `Tensor` | Tensor of shape `(N, H, D)` in rotated space (caller must apply post-rotation if needed). |