turboquant_vllm.vllm.tq4_backend ¶
TQ4 compressed KV cache attention backend for vLLM.
Phase 3c: Packed TQ4 cache layout with real VRAM savings.
The KV cache is stored as uint8 bytes in a packed TQ4 format (68 bytes
per token per head per K/V = 136 bytes total vs 512 bytes FP16 = 3.76x
compression). Buffer allocation uses a custom TQ4FullAttentionSpec
that overrides page_size_bytes so the block allocator provisions
3.76x more blocks in the same VRAM budget. Each forward() call
decompresses the relevant blocks to FP16 and delegates to Flash Attention.
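The byte counts above can be checked with a little arithmetic. This sketch assumes a head size of 128, which is consistent with the quoted figures (64 nibble-packed index bytes plus one fp32 norm per K or V vector):

```python
HEAD_SIZE = 128  # assumed; consistent with the 68-byte figure quoted above

# FP16 baseline: head_size dims x 2 bytes, for both K and V
fp16_bytes = HEAD_SIZE * 2 * 2  # 512 bytes per token per head

# Packed TQ4: one 4-bit index per dim (nibble-packed) + one fp32 norm,
# for each of K and V
tq4_bytes_per_kv = HEAD_SIZE // 2 + 4  # 64 index bytes + 4 norm bytes = 68
tq4_bytes = 2 * tq4_bytes_per_kv       # 136 bytes total

ratio = fp16_bytes / tq4_bytes         # ~3.76x compression
print(fp16_bytes, tq4_bytes, round(ratio, 2))
```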
Implementation phases
- 3a (done): Passthrough skeleton -- validated plugin wiring.
- 3b (done): Compress-decompress round-trip in standard FP16 cache.
- 3c (this): Packed uint8 cache with real VRAM savings.
- 3d: Production benchmark against vLLM baseline.
Classes¶
TQ4FullAttentionSpec
dataclass
¶
Bases: FullAttentionSpec
KV cache spec with TQ4 packed page size.
Overrides real_page_size_bytes so the block allocator provisions
buffers sized for the packed TQ4 format (3.76x smaller than FP16).
Follows the same pattern as MLAAttentionSpec which overrides
page size for the 656-byte FlashMLA format.
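The page-size override can be illustrated with a minimal sketch. The `FullAttentionSpec` stand-in below is hypothetical (the real base class lives in vLLM and has a different signature); only the arithmetic of the override is the point:

```python
from dataclasses import dataclass


@dataclass
class FullAttentionSpec:
    """Hypothetical stand-in for vLLM's spec class (illustration only)."""
    block_size: int
    num_kv_heads: int
    head_size: int

    @property
    def page_size_bytes(self) -> int:
        # FP16 baseline: 2 bytes/elem, K and V planes
        return 2 * self.block_size * self.num_kv_heads * self.head_size * 2


@dataclass
class TQ4FullAttentionSpec(FullAttentionSpec):
    @property
    def page_size_bytes(self) -> int:
        # Packed TQ4: nibble-packed indices + one fp32 norm, K and V planes
        return 2 * self.block_size * self.num_kv_heads * (self.head_size // 2 + 4)
```

With the smaller page size, the block allocator fits ~3.76x more blocks into the same VRAM budget.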
TQ4AttentionBackend ¶
Bases: FlashAttentionBackend
TQ4 compressed KV cache attention backend.
Phase 3c: packed uint8 cache layout with real VRAM savings.
The cache stores nibble-packed TQ4 indices + fp32 norms as raw bytes.
get_kv_cache_shape() returns a 3D (NB, BS, bytes_per_token)
layout matching the packed format.
Functions¶
supports_mm_prefix classmethod ¶
get_name staticmethod ¶
get_impl_cls staticmethod ¶
get_builder_cls staticmethod ¶
get_kv_cache_shape staticmethod ¶
get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]
Packed TQ4 cache: (num_blocks, block_size, total_bytes).
The last dimension packs K and V data for all heads as raw bytes:
[K_indices | K_norms | V_indices | V_norms].
Source code in src/turboquant_vllm/vllm/tq4_backend.py
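The byte layout of the last dimension can be sketched with a small helper. This is an illustrative function, not the backend's actual code; it just computes the segment offsets and the 3D shape described above:

```python
def tq4_cache_layout(num_blocks: int, block_size: int,
                     num_kv_heads: int, head_size: int):
    """Illustrative helper: offsets of each segment in the per-token
    packed layout [K_indices | K_norms | V_indices | V_norms], plus
    the resulting 3D cache shape."""
    idx_bytes = num_kv_heads * head_size // 2  # nibble-packed 4-bit indices
    norm_bytes = num_kv_heads * 4              # one fp32 norm per head
    offsets = {
        "K_indices": 0,
        "K_norms": idx_bytes,
        "V_indices": idx_bytes + norm_bytes,
        "V_norms": 2 * idx_bytes + norm_bytes,
    }
    total_bytes = 2 * (idx_bytes + norm_bytes)
    shape = (num_blocks, block_size, total_bytes)
    return offsets, shape
```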
get_kv_cache_stride_order staticmethod ¶
Raise to trigger identity fallback in reshape.
The inherited FlashAttentionBackend returns a 5-element stride
order for the standard (2, NB, BS, H, D) shape. Our 3D
packed layout (NB, BS, total_bytes) needs identity ordering.
Raising NotImplementedError triggers the fallback in
_reshape_kv_cache_tensors (same pattern as FlashMLA which
does not implement this method at all).
Source code in src/turboquant_vllm/vllm/tq4_backend.py
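The fallback mechanism can be sketched as follows. Both the base class and the reshape helper below are hypothetical stand-ins for vLLM internals; only the raise-to-fall-back pattern is the point:

```python
class FlashAttentionBackend:
    """Hypothetical stand-in for vLLM's base backend (illustration only)."""

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        # 5-element order for the standard (2, NB, BS, H, D) shape
        return (0, 1, 2, 3, 4)


class TQ4AttentionBackend(FlashAttentionBackend):
    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        # Raise so the reshape helper falls back to identity ordering
        # for the 3D packed (NB, BS, total_bytes) layout.
        raise NotImplementedError


def resolve_stride_order(backend, ndim: int) -> tuple[int, ...]:
    """Mimics the fallback in vLLM's _reshape_kv_cache_tensors."""
    try:
        return backend.get_kv_cache_stride_order()
    except NotImplementedError:
        return tuple(range(ndim))  # identity ordering
```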
TQ4AttentionImpl ¶
Bases: FlashAttentionImpl
TQ4 attention: compress -> store -> decompress -> Flash Attention.
Phase 3c: stores packed TQ4 bytes in a uint8 cache for real VRAM
savings. Each forward() call:
- Compresses incoming K/V tokens to TQ4 packed bytes.
- Scatter-writes packed bytes to the uint8 cache via slot_mapping.
- Decompresses the full cache to FP16 for Flash Attention.
- Calls flash_attn_varlen_func directly with the FP16 data.
Initialize TQ4 attention with compression primitives.
Source code in src/turboquant_vllm/vllm/tq4_backend.py
Functions¶
forward ¶
forward(
layer,
query,
key,
value,
kv_cache,
attn_metadata,
output=None,
output_scale=None,
output_block_scale=None,
)
TQ4 attention: compress -> store -> pre-rotate Q -> decompress -> FA -> post-rotate.
Phase 3c.8: Uses fused Triton decompress (no rotation). The rotation is applied to Q before attention and to the output after, saving O(cache_len) matmuls per decode step.
Source code in src/turboquant_vllm/vllm/tq4_backend.py
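The pre-rotate/post-rotate trick works because the rotation is orthogonal, so it cancels inside the attention scores. A minimal NumPy demonstration (not the backend's code, which runs fused Triton kernels on GPU tensors):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
Q = rng.standard_normal((4, d))    # queries
K = rng.standard_normal((16, d))   # cached keys
V = rng.standard_normal((16, d))   # cached values

# Random orthogonal rotation, standing in for the TQ4 rotation
R, _ = np.linalg.qr(rng.standard_normal((d, d)))

def attn(q, k, v):
    s = q @ k.T / np.sqrt(d)
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v

# The cache holds rotated K/V. Rotating Q once before attention and
# un-rotating the output once afterwards reproduces plain attention:
# (QR)(KR)^T = Q K^T since R R^T = I, and (P @ VR) @ R^T = P @ V.
out_ref = attn(Q, K, V)
out_rot = attn(Q @ R, K @ R, V @ R) @ R.T
```

So the O(cache_len) per-token un-rotation of K/V is replaced by two O(1) rotations per decode step, on Q and on the output.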
Functions¶
register_tq4_backend ¶
Register TQ4 as the CUSTOM attention backend.
In addition to registering the backend class, this monkey-patches
Attention.get_kv_cache_spec so that decoder attention layers
return TQ4FullAttentionSpec (with dtype=torch.uint8
and TQ4-sized pages) instead of the standard FullAttentionSpec.
Called automatically by the vllm.general_plugins entry point,
or manually before starting vLLM:
from turboquant_vllm.vllm import register_tq4_backend
register_tq4_backend()
# then start vLLM with --attention-backend CUSTOM