flash_attention
turboquant_vllm.triton.flash_attention ¶
Triton Flash Attention v2 -- forward-only kernel with GQA support.
Phase 1 of the fused TQ4 Flash Attention roadmap (P5). This vanilla kernel matches SDPA output and serves as the scaffold for injecting TQ4 decompression at K/V tile load points in Phase 2.
Supports
- Grouped-Query Attention (GQA) with arbitrary Q/KV head ratios
- Causal and non-causal modes
- Optional additive attention mask (HF-compatible)
- fp32 online softmax accumulation for numerical stability
- RTX 4090 (SM89) and AMD ROCm via Triton HIP backend
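The GQA support above can be illustrated with the standard query-head-to-KV-head grouping. This is a hypothetical sketch, not the kernel's code; the head counts and names below are made-up examples:

```python
# Hypothetical illustration of GQA head grouping, not the kernel's code.
num_q_heads, num_kv_heads = 32, 8              # example 4:1 Q/KV head ratio
group_size = num_q_heads // num_kv_heads       # query heads per shared KV head
kv_head_for = [q_head // group_size for q_head in range(num_q_heads)]
# Query heads 0-3 share KV head 0, heads 4-7 share KV head 1, and so on.
```

Because only the group size enters the mapping, any ratio where the query head count is a multiple of the KV head count works the same way.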
Algorithm
Implements the online softmax from FlashAttention-2 (Dao 2023).
Three fp32 state variables per query row -- running max m_i,
running softmax denominator l_i, and output accumulator acc
-- are maintained across K/V tile iterations. The correction factor
alpha = exp2(m_old - m_new) rescales prior accumulated work when
the running maximum increases. This is mathematically exact, not
approximate.
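The recurrence above can be sketched in plain NumPy for a single query row. This sketch uses `exp` rather than the kernel's `exp2` (the two are equivalent up to folding log2(e) into the scale), and the function name and tile size are illustrative, not the kernel's API:

```python
import numpy as np

def online_softmax_attention(scores, v, block=2):
    # One query row, K/V processed in tiles. Maintains the three fp32
    # state variables from the FlashAttention-2 recurrence: running max m,
    # running denominator l, and unnormalized accumulator acc.
    m = -np.inf
    l = 0.0
    acc = np.zeros(v.shape[1])
    for start in range(0, len(scores), block):
        s = scores[start:start + block]
        m_new = max(m, s.max())
        alpha = np.exp(m - m_new)          # correction factor for prior work
        p = np.exp(s - m_new)
        l = alpha * l + p.sum()
        acc = alpha * acc + p @ v[start:start + block]
        m = m_new
    return acc / l                          # normalize once at the end

# The tiled result matches a one-shot softmax exactly (up to float rounding).
rng = np.random.default_rng(0)
scores = rng.normal(size=6)
v = rng.normal(size=(6, 4))
w = np.exp(scores - scores.max())
ref = (w / w.sum()) @ v
out = online_softmax_attention(scores, v)
```

Note that on the first tile `alpha = exp(-inf - m_new) = 0`, so the (empty) prior state contributes nothing; this is why the recurrence needs no special-case initialization.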
Attributes:

| Name | Type | Description |
|---|---|---|
| triton_flash_attention | Tensor | Python wrapper that launches the Triton kernel with autotuned block sizes. |
Examples:

```python
from turboquant_vllm.triton.flash_attention import triton_flash_attention

out = triton_flash_attention(q, k, v)                  # non-causal
out = triton_flash_attention(q, k, v, is_causal=True)  # prefill
```
See Also

turboquant_vllm.triton.attention_interface: HuggingFace AttentionInterface registration.
Functions¶
triton_flash_attention ¶
```python
triton_flash_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    sm_scale: float | None = None,
    is_causal: bool = False,
    attention_mask: Tensor | None = None,
) -> Tensor
```
Compute scaled dot-product attention using Triton Flash Attention.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| q | Tensor | Query tensor | required |
| k | Tensor | Key tensor | required |
| v | Tensor | Value tensor | required |
| sm_scale | float \| None | Softmax scale factor. Defaults to … | None |
| is_causal | bool | Apply causal masking. Only valid when … | False |
| attention_mask | Tensor \| None | Optional additive mask | None |
Returns:

| Type | Description |
|---|---|
| Tensor | Attention output |
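Since the kernel is documented to match SDPA output, the expected semantics can be sketched as a NumPy reference. The `[batch, heads, seq_len, head_dim]` layout and the `1/sqrt(head_dim)` fallback for `sm_scale` are assumptions not confirmed on this page, and this sketch omits the GQA head mapping the real kernel performs:

```python
import numpy as np

def sdpa_reference(q, k, v, sm_scale=None, is_causal=False, attention_mask=None):
    # NumPy sketch of the semantics the Triton kernel is documented to match.
    # Assumes [batch, heads, seq_len, head_dim] layout and equal Q/KV head
    # counts; the real kernel additionally handles GQA head mapping.
    if sm_scale is None:
        sm_scale = 1.0 / np.sqrt(q.shape[-1])        # assumed fallback
    scores = q @ np.swapaxes(k, -1, -2) * sm_scale
    if is_causal:
        sq, sk = scores.shape[-2:]
        future = np.triu(np.ones((sq, sk), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)   # mask out future keys
    if attention_mask is not None:
        scores = scores + attention_mask             # additive (HF-style) mask
    scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax
    p = np.exp(scores)
    return (p / p.sum(axis=-1, keepdims=True)) @ v
```

One useful sanity check of the causal path: the first query position can only attend to the first key, so its output row must equal the first value row exactly.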
Source code in src/turboquant_vllm/triton/flash_attention.py