molmo2_integration

turboquant_vllm.triton.molmo2_integration

Fused TurboQuant attention integration for Molmo2 models.

Patches Molmo2 attention layers to compute Q @ K^T directly from nibble-packed 4-bit compressed keys using the fused Triton kernel. Keys are never materialized as full fp16 tensors during attention.

Values are stored uncompressed in fp16 (the softmax @ V path benefits less from compression and doesn't need a fused kernel).
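The 4-bit nibble packing can be illustrated in isolation. A minimal pure-Python sketch (the real path operates on torch tensors, mirroring the expression `(indices[..., 0::2] << 4) | indices[..., 1::2]` used in `update()`):

```python
# Minimal pure-Python sketch of the nibble-packing scheme: two 4-bit code
# indices share one uint8 byte, even positions in the high nibble, odd
# positions in the low nibble.

def pack_nibbles(indices: list[int]) -> list[int]:
    """Pack pairs of 4-bit values (0..15) into single bytes."""
    assert len(indices) % 2 == 0
    return [(indices[i] << 4) | indices[i + 1] for i in range(0, len(indices), 2)]

def unpack_nibbles(packed: list[int]) -> list[int]:
    """Inverse of pack_nibbles: split each byte back into two 4-bit values."""
    out: list[int] = []
    for byte in packed:
        out.extend((byte >> 4, byte & 0x0F))
    return out

codes = [3, 12, 0, 15]
assert unpack_nibbles(pack_nibbles(codes)) == codes
```

This halves storage relative to one byte per index, at the cost of an unpack step inside the kernel.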

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `FusedTurboQuantRunner` | | High-level runner that patches a Molmo2 model, generates text, and cleans up. |
| `install_fused_attention` | | Low-level function to patch attention layers; returns a `CompressedKVStore`. |

Examples:

runner = FusedTurboQuantRunner(model, processor, bits=4)
text, stats = runner.generate(
    prompt="Describe this scene.",
    video_path="/path/to/video.mp4",
    max_new_tokens=256,
)
See Also

- :mod:`turboquant_vllm.triton.fused_qk_attention`: The Triton kernel.
- :mod:`turboquant_vllm.kv_cache`: Unfused CompressedDynamicCache.

Classes

CompressedKVStore

CompressedKVStore(quantizer: TurboQuantMSE)

Bases: DynamicCache

KV store with compressed keys and standard values.

Keys are compressed into nibble-packed uint8 indices + fp32 norms in side storage for the fused Triton kernel. Values and all DynamicLayer bookkeeping are managed by the base DynamicCache via the overridden update() method.

This cache is passed as past_key_values to model.generate().
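A toy sketch of why this storage layout suffices for attention scores. ASSUMPTIONS: the centroid table, rotation handling, and norm semantics below are illustrative stand-ins, not the real `TurboQuantMSE` internals; a 2-bit codebook is used for brevity.

```python
# Toy sketch of scoring a query against a compressed key row: only centroid
# indices plus a stored norm are needed, never a materialized fp16 key.
# ASSUMED values throughout; the real TurboQuantMSE codebook differs.
centroids = [-1.5, -0.5, 0.5, 1.5]   # hypothetical centroid table
codes = [3, 0, 2, 1]                 # per-coordinate centroid indices for one key row
norm = 2.0                           # stored fp32 row norm

k_hat = [norm * centroids[c] for c in codes]          # dequantized key row
q = [1.0, 0.0, 1.0, 0.0]                              # (pre-rotated) query row
score = sum(qi * ki for qi, ki in zip(q, k_hat))      # one entry of Q @ K^T
```

The fused kernel performs this dequantize-and-dot step on the fly, which is why full keys are never materialized.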

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `quantizer` | `TurboQuantMSE` | The TQ4 quantizer instance. |
| `rotation_T` | `Tensor` | Transposed rotation matrix for query pre-rotation, shape `(head_dim, head_dim)`. |
| `centroids` | `Tensor` | Lloyd-Max centroid values, shape `(n_levels,)`. |

Examples:

store = CompressedKVStore(quantizer=tq)
store.update(key_states, value_states, layer_idx=0)

Initialize the compressed KV store.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `quantizer` | `TurboQuantMSE` | TurboQuantMSE instance for key compression. | *required* |
Source code in src/turboquant_vllm/triton/molmo2_integration.py
def __init__(self, quantizer: TurboQuantMSE) -> None:
    """Initialize the compressed KV store.

    Args:
        quantizer: TurboQuantMSE instance for key compression.
    """
    super().__init__()
    self.quantizer = quantizer
    self.rotation_T = quantizer.rotation.T.contiguous()
    self.centroids = quantizer.codebook.centroids.contiguous()

    self._packed_indices: list[torch.Tensor | None] = []
    self._norms: list[torch.Tensor | None] = []
Functions
update
update(
    key_states: Tensor,
    value_states: Tensor,
    layer_idx: int,
    cache_kwargs: dict[str, Any] | None = None,
) -> tuple[Tensor, Tensor]

Compress keys on write, store values normally via DynamicCache.

Overrides DynamicCache.update() to intercept key storage. Keys are nibble-packed into compressed side storage for the fused Triton kernel. Values and all DynamicLayer bookkeeping (seq_length, layer creation, offloading) are handled by the base class, keeping model.generate() happy.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `key_states` | `Tensor` | Key tensor, shape `(batch, n_kv_heads, seq_len, head_dim)`. | *required* |
| `value_states` | `Tensor` | Value tensor, same shape as `key_states`. | *required* |
| `layer_idx` | `int` | Transformer layer index. | *required* |
| `cache_kwargs` | `dict[str, Any] \| None` | Additional cache arguments (passed to base). | `None` |

Returns:

| Type | Description |
| --- | --- |
| `tuple[Tensor, Tensor]` | Tuple of (full_keys, full_values) from the base class. The returned keys are uncompressed (for compatibility); the fused kernel reads from compressed side storage instead. |

Source code in src/turboquant_vllm/triton/molmo2_integration.py
def update(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    cache_kwargs: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compress keys on write, store values normally via DynamicCache.

    Overrides ``DynamicCache.update()`` to intercept key storage.
    Keys are nibble-packed into compressed side storage for the
    fused Triton kernel. Values and all ``DynamicLayer`` bookkeeping
    (seq_length, layer creation, offloading) are handled by the
    base class, keeping ``model.generate()`` happy.

    Args:
        key_states: Key tensor, shape
            ``(batch, n_kv_heads, seq_len, head_dim)``.
        value_states: Value tensor, same shape as key_states.
        layer_idx: Transformer layer index.
        cache_kwargs: Additional cache arguments (passed to base).

    Returns:
        Tuple of (full_keys, full_values) from the base class.
        The returned keys are uncompressed (for compatibility);
        the fused kernel reads from compressed side storage instead.
    """
    # 1. Compress keys into side storage
    indices, norms = self.quantizer.quantize(key_states.float())
    indices = indices.to(torch.uint8)
    packed = (indices[..., 0::2] << 4) | indices[..., 1::2]
    norms = norms.float().squeeze(-1)

    while len(self._packed_indices) <= layer_idx:
        self._packed_indices.append(None)
        self._norms.append(None)

    if self._packed_indices[layer_idx] is None:
        self._packed_indices[layer_idx] = packed
        self._norms[layer_idx] = norms
    else:
        self._packed_indices[layer_idx] = torch.cat(
            [self._packed_indices[layer_idx], packed], dim=2
        )
        self._norms[layer_idx] = torch.cat([self._norms[layer_idx], norms], dim=2)

    # 2. Let DynamicCache handle values + DynamicLayer bookkeeping.
    #    Keys are passed through for seq_length tracking.
    return super().update(key_states, value_states, layer_idx, cache_kwargs)
get_compressed_key
get_compressed_key(layer_idx: int) -> tuple[Tensor, Tensor]

Return compressed key data for a layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `layer_idx` | `int` | Transformer layer index. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `tuple[Tensor, Tensor]` | Tuple of (packed_indices, norms). |

Source code in src/turboquant_vllm/triton/molmo2_integration.py
def get_compressed_key(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Return compressed key data for a layer.

    Args:
        layer_idx: Transformer layer index.

    Returns:
        Tuple of (packed_indices, norms).
    """
    return self._packed_indices[layer_idx], self._norms[layer_idx]

FusedTurboQuantRunner

FusedTurboQuantRunner(model: Module, processor: Any, bits: int = 4, *, seed: int = 42)

High-level runner for fused TurboQuant inference on Molmo2.

Patches the model, runs inference, and cleans up. Handles both text-only and video inputs.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `model` | `Module` | The Molmo2 model. |
| `processor` | `Any` | The Molmo2 processor. |
| `bits` | `int` | Quantization bit width. |
Examples:

runner = FusedTurboQuantRunner(model, processor, bits=4)
text, stats = runner.generate("Describe this.", max_new_tokens=256)
print(text)

Initialize the runner.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | A loaded Molmo2 model. | *required* |
| `processor` | `Any` | The corresponding Molmo2 processor. | *required* |
| `bits` | `int` | Quantization bits (default 4 for nibble packing). | `4` |
| `seed` | `int` | Random seed for reproducibility. | `42` |
Source code in src/turboquant_vllm/triton/molmo2_integration.py
def __init__(
    self,
    model: nn.Module,
    processor: Any,
    bits: int = 4,
    *,
    seed: int = 42,
) -> None:
    """Initialize the runner.

    Args:
        model: A loaded Molmo2 model.
        processor: The corresponding Molmo2 processor.
        bits: Quantization bits (default 4 for nibble packing).
        seed: Random seed for reproducibility.
    """
    self.model = model
    self.processor = processor
    self.bits = bits
    self.seed = seed
Functions
generate
generate(
    prompt: str, video_path: str | None = None, max_new_tokens: int = 256
) -> tuple[str, dict]

Generate text with fused TurboQuant attention.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `prompt` | `str` | Text prompt. | *required* |
| `video_path` | `str \| None` | Optional path to a video file. | `None` |
| `max_new_tokens` | `int` | Maximum tokens to generate. | `256` |

Returns:

| Type | Description |
| --- | --- |
| `tuple[str, dict]` | Tuple of (generated_text, stats_dict). |

Source code in src/turboquant_vllm/triton/molmo2_integration.py
def generate(
    self,
    prompt: str,
    video_path: str | None = None,
    max_new_tokens: int = 256,
) -> tuple[str, dict]:
    """Generate text with fused TurboQuant attention.

    Args:
        prompt: Text prompt.
        video_path: Optional path to a video file.
        max_new_tokens: Maximum tokens to generate.

    Returns:
        Tuple of (generated_text, stats_dict).
    """
    import time

    # Build input
    content: list[dict] = []
    if video_path:
        content.append({"type": "video", "video": video_path})
    content.append({"type": "text", "text": prompt})

    messages = [{"role": "user", "content": content}]
    inputs = self.processor.apply_chat_template(
        messages,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    inputs = {
        k: v.to(self.model.device) if hasattr(v, "to") else v
        for k, v in inputs.items()
    }
    input_len = inputs["input_ids"].shape[-1]

    # Install fused attention
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

    store = install_fused_attention(self.model, self.bits, seed=self.seed)

    # Generate
    t0 = time.perf_counter()
    with torch.inference_mode():
        output_ids = self.model.generate(  # ty: ignore[call-non-callable]
            **inputs,
            past_key_values=store,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            use_cache=True,
        )
    elapsed = time.perf_counter() - t0
    vram_peak = torch.cuda.max_memory_allocated() / (1024 * 1024)

    # Decode
    generated_ids = output_ids[0, input_len:]
    text = self.processor.decode(generated_ids, skip_special_tokens=True)

    output_len = len(generated_ids)
    tok_per_sec = output_len / elapsed if elapsed > 0 else 0

    # Uninstall
    uninstall_fused_attention(self.model)

    stats = {
        "input_tokens": input_len,
        "output_tokens": output_len,
        "tokens_per_sec": round(tok_per_sec, 1),
        "elapsed_s": round(elapsed, 2),
        "vram_peak_mib": round(vram_peak, 1),
        "bits": self.bits,
        "kv_seq_len": store.get_seq_length(),
    }

    return text, stats

Functions

install_fused_attention

install_fused_attention(
    model: Module, bits: int = 4, *, seed: int = 42
) -> CompressedKVStore

Patch all Molmo2 text attention layers to use fused TurboQuant.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | A loaded Molmo2 model. | *required* |
| `bits` | `int` | Quantization bits per coordinate (default 4 for nibble packing). | `4` |
| `seed` | `int` | Random seed for reproducibility. | `42` |

Returns:

| Type | Description |
| --- | --- |
| `CompressedKVStore` | A CompressedKVStore to pass as `past_key_values` to `model.generate()`. |

Source code in src/turboquant_vllm/triton/molmo2_integration.py
def install_fused_attention(
    model: nn.Module,
    bits: int = 4,
    *,
    seed: int = 42,
) -> CompressedKVStore:
    """Patch all Molmo2 text attention layers to use fused TurboQuant.

    Args:
        model: A loaded Molmo2 model.
        bits: Quantization bits per coordinate (default 4 for nibble packing).
        seed: Random seed for reproducibility.

    Returns:
        A CompressedKVStore to pass as ``past_key_values`` to
        ``model.generate()``.
    """
    # Detect head_dim from model config
    config = model.config
    text_config = getattr(config, "text_config", config)
    head_dim = getattr(text_config, "head_dim", 128)

    # Create quantizer and store
    tq = TurboQuantMSE(head_dim, bits, seed=seed)
    kv_store = CompressedKVStore(tq)

    # Find and patch text attention layers
    patched = 0
    for name, module in model.named_modules():
        if hasattr(module, "att_proj") and hasattr(module, "attn_out"):
            layer_idx = getattr(module, "layer_idx", patched)
            module._original_forward = module.forward
            module.forward = _make_fused_forward(module, kv_store, layer_idx)
            patched += 1

    print(f"  Installed fused TurboQuant TQ{bits} on {patched} attention layers")
    return kv_store
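The `head_dim` detection above uses a getattr-chain probe over the model config. A sketch of that pattern in isolation, with `SimpleNamespace` standing in for a Hugging Face config object:

```python
# Sketch of the config-probing pattern: prefer a nested text_config, fall
# back to the top-level config, default head_dim to 128.
from types import SimpleNamespace

def detect_head_dim(config) -> int:
    text_config = getattr(config, "text_config", config)
    return getattr(text_config, "head_dim", 128)

assert detect_head_dim(SimpleNamespace(text_config=SimpleNamespace(head_dim=64))) == 64
assert detect_head_dim(SimpleNamespace()) == 128
```

The double `getattr` fallback keeps the function working for both multimodal configs (nested `text_config`) and plain text-model configs.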

uninstall_fused_attention

uninstall_fused_attention(model: Module) -> None

Restore original attention forwards.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | The patched Molmo2 model. | *required* |
Source code in src/turboquant_vllm/triton/molmo2_integration.py
def uninstall_fused_attention(model: nn.Module) -> None:
    """Restore original attention forwards.

    Args:
        model: The patched Molmo2 model.
    """
    restored = 0
    for _name, module in model.named_modules():
        if hasattr(module, "_original_forward"):
            module.forward = module._original_forward
            del module._original_forward
            restored += 1
    if restored:
        print(f"  Restored {restored} original attention forwards")
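The patch/restore pattern shared by `install_fused_attention` and `uninstall_fused_attention` can be sketched with a plain class standing in for `nn.Module`:

```python
# Minimal sketch of the patch/restore pattern: the original bound forward is
# stashed on the module instance, an instance attribute shadows it, and
# deleting the shadow marker restores the original behavior.
class FakeAttention:
    def forward(self) -> str:
        return "original"

def patch(module: FakeAttention) -> None:
    module._original_forward = module.forward        # stash bound method
    module.forward = lambda: "fused"                 # shadow with instance attribute

def unpatch(module: FakeAttention) -> None:
    if hasattr(module, "_original_forward"):
        module.forward = module._original_forward
        del module._original_forward

m = FakeAttention()
patch(m)
assert m.forward() == "fused"
unpatch(m)
assert m.forward() == "original"
```

The `hasattr` guard makes uninstall idempotent: unpatched modules (and repeated calls) are skipped, matching the loop in `uninstall_fused_attention`.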