turboquant_vllm.triton.fused_paged_tq4_attention
Fused paged TQ4 decode attention -- decompresses directly from page table.
Phase 3a of the D9 kernel roadmap. This kernel reads TQ4-compressed blocks directly from vLLM's paged block table, decompresses in SRAM (nibble unpack -> centroid gather -> norm scale), and computes FP16 Q@K^T with online softmax in a single fused pass. No HBM writes of decompressed cache -- HBM traffic drops from 1,160 to 136 bytes/token (8.5x reduction).
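For intuition, here is a minimal PyTorch sketch of the online-softmax accumulation that such a single-pass decode performs for one query row over K/V tiles. The names and tiling are illustrative only, not the kernel's actual Triton variables; in the real kernel each `k`/`v` tile is the TQ4 block decompressed in SRAM.

```python
import torch

def online_softmax_decode(q, k_tiles, v_tiles, scale):
    """One query row of shape (head_dim,) attended over (BLOCK_N, head_dim) tiles."""
    m = torch.tensor(float("-inf"))        # running row max
    l = torch.tensor(0.0)                  # running softmax denominator
    acc = torch.zeros_like(v_tiles[0][0])  # running weighted-V numerator
    for k, v in zip(k_tiles, v_tiles):     # in the kernel: decompressed TQ4 tiles
        s = (k @ q) * scale                # scores for this tile: (BLOCK_N,)
        m_new = torch.maximum(m, s.max())
        alpha = torch.exp(m - m_new)       # rescale previous partial results
        p = torch.exp(s - m_new)           # unnormalized tile probabilities
        acc = acc * alpha + p @ v
        l = l * alpha + p.sum()
        m = m_new
    return acc / l                         # normalize once at the end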
The kernel operates entirely in rotated space: the caller pre-rotates Q by Pi^T and post-rotates the output by Pi. Decompression does NOT apply rotation (matching tq4_decompress.py).
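A rough PyTorch reference for that decompress pipeline (nibble unpack -> centroid gather -> norm scale). The layouts here are assumptions for illustration: packed codes as uint8 with the low nibble holding the even element, a shared 16-entry codebook, and a per-token norm. The real kernel performs the same steps in SRAM inside the attention tile loop.

```python
import torch

def tq4_decompress_reference(packed: torch.Tensor,    # (tokens, head_dim // 2), uint8
                             centroids: torch.Tensor, # (16,) codebook, fp16
                             norms: torch.Tensor):    # (tokens,) per-token scale
    lo = packed & 0x0F                                 # low nibble  -> even elements (assumed order)
    hi = packed >> 4                                   # high nibble -> odd elements
    codes = torch.stack((lo, hi), dim=-1).flatten(-2)  # (tokens, head_dim) 4-bit codes
    vals = centroids[codes.long()]                     # centroid gather
    return vals * norms.unsqueeze(-1)                  # norm scale; output stays in rotated space
```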
Scope: FP16/BF16 Q decode path only (USE_INT8_QK=False). The INT8 path is Story 6.4. Placeholder parameters are included for forward compatibility but are compiled out by the constexpr switch.
Autotune: 8 configs (BLOCK_N in {32, 64} x stages {2,3} x warps {4,8}). BLOCK_N=16 dropped after Experiment 020 profiling showed it consistently slowest across 1K-32K context on RTX 4090.
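To make that search space concrete, the following is a hypothetical reconstruction with the Triton autotuner API; the list name is illustrative, and the actual kernel's config list may pair stages and warps differently.

```python
import triton

# Hypothetical reconstruction of the 8-config space described above:
# BLOCK_N in {32, 64} x num_stages in {2, 3} x num_warps in {4, 8}.
TQ4_CONFIGS = [
    triton.Config({"BLOCK_N": block_n}, num_stages=stages, num_warps=warps)
    for block_n in (32, 64)   # BLOCK_N=16 dropped after Experiment 020 profiling
    for stages in (2, 3)
    for warps in (4, 8)
]
```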
Attributes:

| Name | Type | Description |
|---|---|---|
| fused_paged_tq4_decode | Tensor | Python wrapper that pre-rotates Q, launches the fused paged kernel, and post-rotates the output. |
Examples:

```python
from turboquant_vllm.triton.fused_paged_tq4_attention import (
    fused_paged_tq4_decode,
)

out = fused_paged_tq4_decode(
    q,
    kv_cache,
    block_table,
    seq_lens,
    centroids,
    rotation,
    num_kv_heads=4,
    head_dim=128,
    block_size=16,
)
```
See Also

:mod:`turboquant_vllm.triton.flash_attention_tq4_kv`: Contiguous (non-paged) reference kernel -- correctness baseline.
:mod:`turboquant_vllm.triton.tq4_decompress`: Standalone decompress.
Functions

fused_paged_tq4_decode
```python
fused_paged_tq4_decode(
    q: Tensor,
    kv_cache: Tensor,
    block_table: Tensor,
    seq_lens: Tensor,
    centroids: Tensor,
    rotation: Tensor,
    num_kv_heads: int,
    head_dim: int,
    block_size: int,
    sm_scale: float | None = None,
    out: Tensor | None = None,
) -> Tensor
```
Fused paged TQ4 decode attention.

Pre-rotates Q by rotation^T, launches the fused paged kernel that decompresses TQ4 blocks in-tile from the page table, then post-rotates the output by rotation to return to the original space.
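A minimal sketch of that wrapping, assuming a row-vector convention (`q` of shape `(..., head_dim)`, `rotation` orthogonal of shape `(head_dim, head_dim)`); `launch` stands in for the fused paged Triton kernel and is not the module's actual symbol.

```python
import torch

def rotate_wrap(q: torch.Tensor, rotation: torch.Tensor, launch) -> torch.Tensor:
    # Pre-rotate Q by Pi^T; rotation is orthogonal, so .t() is its inverse.
    q_rot = q @ rotation.t()
    # `launch` is a placeholder for the fused paged kernel (rotated space in/out).
    out_rot = launch(q_rot)
    # Post-rotate by Pi to return to the original space.
    return out_rot @ rotation
```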
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| q | Tensor | Query | required |
| kv_cache | Tensor | Packed paged cache | required |
| block_table | Tensor | Page table | required |
| seq_lens | Tensor | Sequence lengths | required |
| centroids | Tensor | TQ4 codebook | required |
| rotation | Tensor | Orthogonal rotation | required |
| num_kv_heads | int | Number of KV heads. | required |
| head_dim | int | Head dimension (e.g. 128). | required |
| block_size | int | vLLM page size (tokens per block). | required |
| sm_scale | float \| None | Softmax scale. Defaults to 1/sqrt(head_dim). | None |
| out | Tensor \| None | Optional pre-allocated output. | None |
Returns:

| Type | Description |
|---|---|
| Tensor | Attention output. When out is provided, the result is written into it in-place. |
Note

INT8 placeholder parameters (Q_scale, QJL_S, QJL_signs, QJL_residual_norms) should be passed as None/zeros when USE_INT8_QK=False (the default for Phase 3a).
Source code in src/turboquant_vllm/triton/fused_paged_tq4_attention.py (lines 255-370).