vLLM Internals

Attention Layer & Backends

Attention is the hottest path in the model. vLLM separates the logical attention layer (what it computes) from the backend (how it computes it). The backend is chosen based on the hardware and model configuration, and prefill and decode tokens can take different kernel paths within it.

References: FlashAttention | FlashAttention-2 | FlashInfer

Attention Basics

Multi-Head Attention with KV Cache

Attn(Q, K, V) = softmax( Q·Kᵀ / √dₖ ) · V

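In standard attention, Q, K and V all come from the current input. In autoregressive decoding, only the current token's projections are new: its query attends over the keys and values of all prior tokens, read from the KV cache, plus its own newly computed key and value, which are appended to the cache for later steps.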
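A minimal NumPy sketch of a single decode step for one attention head makes this data flow concrete. The function name and array shapes are illustrative assumptions, not vLLM code: the current token's key and value are appended to the cache, then its query attends over every cached position.

```python
import numpy as np


def decode_step(q, k_new, v_new, k_cache, v_cache):
    """One autoregressive step for a single attention head.

    q, k_new, v_new: [head_dim] projections of the current token.
    k_cache, v_cache: [num_prior_tokens, head_dim] from earlier steps.
    Returns the attention output and the updated caches.
    """
    # Append the current token's K/V so it can attend to itself
    k_cache = np.vstack([k_cache, k_new])
    v_cache = np.vstack([v_cache, v_new])

    d_k = q.shape[-1]
    scores = k_cache @ q / np.sqrt(d_k)       # [num_tokens]
    weights = np.exp(scores - scores.max())   # numerically stable softmax
    weights /= weights.sum()
    out = weights @ v_cache                   # [head_dim]
    return out, k_cache, v_cache
```

Running it token by token from an empty cache reproduces the rows of full causal attention while computing each token's K and V projections only once.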

# Attention layer in a vLLM model (simplified sketch of LlamaAttention;
# the real class fuses the three projections into a single qkv_proj)
from torch import nn

class LlamaAttention(nn.Module):
    def forward(self, positions, hidden_states, kv_cache, attn_metadata):
        # Project to Q, K, V
        q = self.q_proj(hidden_states)  # [num_tokens, num_heads * head_dim]
        k = self.k_proj(hidden_states)  # [num_tokens, num_kv_heads * head_dim]
        v = self.v_proj(hidden_states)  # [num_tokens, num_kv_heads * head_dim]

        # Apply rotary positional embeddings to Q and K (V is not rotated)
        q, k = self.rotary_emb(positions, q, k)

        # Paged attention: writes the new K/V into kv_cache blocks,
        # then reads all cached K/V for each sequence
        output = self.attn(q, k, v, kv_cache, attn_metadata)

        return self.o_proj(output)  # [num_tokens, hidden_dim]

Grouped Query Attention (GQA)

Modern models (Llama 3, Gemma, Mistral) use GQA: fewer KV heads than Q heads, with each KV head shared by a group of query heads. For example, Llama-3-8B has 32 Q heads but only 8 KV heads, so its KV cache is 4× smaller than it would be with one KV head per query head, with little loss in model quality.
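The 4× figure follows from the per-token KV cache footprint. The sketch below assumes Llama-3-8B's configuration (32 layers, head_dim 128) and a 2-byte fp16/bf16 cache; the helper names are hypothetical. It also shows the usual GQA head mapping, in which each run of num_q_heads // num_kv_heads consecutive query heads reads the same KV head.

```python
def kv_cache_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype_bytes=2):
    """Bytes of KV cache one token occupies across all layers (K and V)."""
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


def kv_head_for_query_head(q_head, num_q_heads, num_kv_heads):
    """Under GQA, consecutive groups of query heads share one KV head."""
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


# Llama-3-8B-style shape: 32 layers, 32 Q heads, head_dim 128, 2-byte cache
gqa = kv_cache_bytes_per_token(32, 8, 128)    # 8 KV heads (GQA)
mha = kv_cache_bytes_per_token(32, 32, 128)   # hypothetical 32 KV heads (MHA)
print(gqa // 1024, "KiB vs", mha // 1024, "KiB per token")  # 128 KiB vs 512 KiB per token
```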

Attention Class — Logical Layer

Source: vllm/model_executor/layers/attention/attention.py

The Attention class is platform-agnostic. It delegates to a backend via the AttentionBackend interface:

import torch
from torch import nn

class Attention(nn.Module, AttentionLayerBase):
    def __init__(self, num_heads, head_size, scale, ...):
        super().__init__()
        # Pick a backend class for this platform and model configuration
        self.backend = get_attn_backend(
            num_heads, head_size, ..., vllm_config
        )
        # Instantiate the backend's implementation with the layer's shape info
        impl_cls = self.backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, ...)

    def forward(
        self,
        query: torch.Tensor,     # [num_tokens, num_heads * head_dim]
        key: torch.Tensor,       # [num_tokens, num_kv_heads * head_dim]
        value: torch.Tensor,     # [num_tokens, num_kv_heads * head_dim]
        kv_cache: torch.Tensor,  # the physical GPU cache tensor
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # The backend writes the new K, V into kv_cache at the slots given by
        # attn_metadata, then computes attention over all cached positions
        return self.impl.forward(query, key, value, kv_cache, attn_metadata)
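To make the split between the logical layer and the backend concrete, here is a self-contained toy version of the same delegation. Every name in it (NaiveBackend, NaiveImpl, select_backend, ToyAttention) is hypothetical, and the registry is keyed by a plain string, whereas vLLM's selector inspects the platform and model configuration. The get_impl_cls() hook mirrors the call in the snippet above; the impl is plain single-head softmax attention with no mask and no KV cache.

```python
import numpy as np


class NaiveImpl:
    """Hypothetical impl: single-head softmax attention, no mask, no cache."""

    def __init__(self, scale):
        self.scale = scale

    def forward(self, q, k, v):
        scores = q @ k.T * self.scale                         # [num_q, num_kv]
        w = np.exp(scores - scores.max(axis=-1, keepdims=True))
        w /= w.sum(axis=-1, keepdims=True)                    # row-wise softmax
        return w @ v


class NaiveBackend:
    """Hypothetical backend: a bundle of static hooks, loosely like AttentionBackend."""

    @staticmethod
    def get_name():
        return "NAIVE"

    @staticmethod
    def get_impl_cls():
        return NaiveImpl


# Hypothetical registry; a real selector would inspect GPU type, dtype, head size
_BACKENDS = {"cpu": NaiveBackend}


def select_backend(platform):
    return _BACKENDS[platform]


class ToyAttention:
    """Logical layer: owns the shapes and scale, delegates the kernel."""

    def __init__(self, head_size, platform="cpu"):
        self.backend = select_backend(platform)
        impl_cls = self.backend.get_impl_cls()
        self.impl = impl_cls(scale=head_size ** -0.5)

    def forward(self, q, k, v):
        return self.impl.forward(q, k, v)
```

Swapping in a faster kernel means registering another backend; ToyAttention and the model code that calls it do not change.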

Attention Backends

Backend Selection

Source: vllm/v1/attention/backend.py

Backend | Best for | Notes
FlashInferBackend | NVIDIA A100/H100, decode | Fastest decode via paged KV; supports ragged batches natively
FlashAttentionBackend | NVIDIA, prefill | Optimal for long-context prefill; uses tiling to stay in SRAM
XFormersBackend | Older NVIDIA, AMD | Fallback; uses xformers memory-efficient attention
TritonBackend | AMD ROCm | Pure Triton kernel; portable across GPU architectures
TRTLLMBackend | NVIDIA, TensorRT | TensorRT-LLM kernels for highest throughput on Hopper

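Because prefill and decode stress the hardware differently, vLLM can serve them with different kernels within the same model: a FlashAttention-style tiled kernel suits prefill (long query sequences), while a paged decode kernel such as FlashInfer's suits decode (short queries over a long KV history).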

FlashInfer — Paged Decode

FlashInfer is purpose-built for paged KV cache inference. It natively reads non-contiguous (paged) K/V memory through block tables, which the original FlashAttention kernels assumed was contiguous (later FlashAttention releases added a block-table option).

# FlashInfer decode kernel (conceptual, one head):
# For each decoding request i (one query token q_i each):
#   m, l, acc = -inf, 0, 0                     # online-softmax state
#   For logical block j, physical block p in enumerate(block_table[i]):
#     k_blk = k_cache[p]                       # keys of the tokens in block p
#     v_blk = v_cache[p]                       # values of the same tokens
#     s = (q_i · k_blk^T) * scale              # only valid tokens in the last block
#     m_new = max(m, max(s))
#     acc = acc * exp(m - m_new) + exp(s - m_new) · v_blk
#     l   = l   * exp(m - m_new) + sum(exp(s - m_new))
#     m = m_new
#   output[i] = acc / l
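The conceptual loop can be made runnable in NumPy. This single-head sketch assumes separate key and value block arrays of shape [num_blocks, block_size, head_dim] (vLLM's real cache layout differs by backend and adds a head dimension), and the function name is hypothetical. It gathers blocks through the block table and combines them with the online-softmax rescaling that flash-style kernels use, so no full score vector is ever materialized.

```python
import numpy as np


def paged_decode(q, k_cache, v_cache, block_table, seq_len, scale):
    """Single-head decode attention over a paged KV cache with online softmax.

    q:           [head_dim] query of the current token
    k_cache:     [num_blocks, block_size, head_dim] physical key blocks
    v_cache:     [num_blocks, block_size, head_dim] physical value blocks
    block_table: physical block IDs, indexed by logical block number
    seq_len:     number of valid tokens for this request
    """
    block_size = k_cache.shape[1]
    m = -np.inf                 # running max of scores
    l = 0.0                     # running sum of exp(score - m)
    acc = np.zeros_like(q)      # running weighted sum of values
    for logical_blk, phys_blk in enumerate(block_table):
        start = logical_blk * block_size
        if start >= seq_len:
            break
        n = min(block_size, seq_len - start)   # the last block may be partial
        k_blk = k_cache[phys_blk, :n]
        v_blk = v_cache[phys_blk, :n]
        s = k_blk @ q * scale                  # [n]
        m_new = max(m, s.max())
        correction = np.exp(m - m_new)         # rescale earlier partial sums
        p = np.exp(s - m_new)
        l = l * correction + p.sum()
        acc = acc * correction + p @ v_blk
        m = m_new
    return acc / l
```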

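FlashInfer uses a plan/run API backed by a workspace buffer. plan() takes the batch's page indptr, page indices, and last-page lengths and precomputes a kernel schedule once per batch; run() then executes the kernel with low overhead and can be called for every layer, since all layers in a step share the same batch layout.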

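# vLLM's use of FlashInfer (simplified)
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

# Scratch space shared by the planner and the kernel
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer)

# Plan once per batch: a CSR description of each request's pages
wrapper.plan(
    indptr=seq_indptr,                # [batch + 1] offsets of each request's pages in indices
    indices=block_tables_flat,        # physical page IDs, concatenated per request
    last_page_len=last_page_lengths,  # [batch] valid tokens in each request's last page
    num_qo_heads=num_q_heads,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,
    page_size=block_size,
)
# Run for each layer, reusing the plan
output = wrapper.run(query, kv_cache)  # query: [batch, num_qo_heads, head_dim]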
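The three CSR-style arrays passed to plan() can be derived from per-request block tables. A minimal pure-Python sketch; the helper name build_paged_kv_csr is hypothetical, and a real caller would convert the lists to int32 tensors on the GPU before calling plan().

```python
def build_paged_kv_csr(block_tables, seq_lens, page_size):
    """Flatten per-request block lists into CSR arrays for a paged decode kernel.

    block_tables: list of physical block ID lists, one per request
    seq_lens:     total KV length of each request (including the new token)
    Returns (indptr, indices, last_page_len) as Python lists.
    """
    indptr, indices, last_page_len = [0], [], []
    for blocks, seq_len in zip(block_tables, seq_lens):
        num_pages = (seq_len + page_size - 1) // page_size   # ceil division
        indices.extend(blocks[:num_pages])
        indptr.append(len(indices))
        # Tokens used in the final page (page_size when it is exactly full)
        last_page_len.append(seq_len - (num_pages - 1) * page_size)
    return indptr, indices, last_page_len


# Two requests: 10 tokens in pages [7, 2, 5], and 3 tokens in page [3]
print(build_paged_kv_csr([[7, 2, 5], [3]], [10, 3], 4))
# ([0, 3, 4], [7, 2, 5, 3], [2, 3])
```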

AttentionMetadata

Per-Iteration Attention State

Source: vllm/v1/attention/backend.py

Constructed by _prepare_inputs() in the model runner. Carries all per-batch metadata needed by the attention kernel:

from dataclasses import dataclass

import torch

@dataclass
class AttentionMetadata:
    # Which tokens are in the prefill vs. decode phase
    num_prefill_tokens: int
    num_decode_tokens: int

    # Sequence lengths (including tokens already in the cache)
    seq_lens: list[int]           # total context length per sequence
    query_lens: list[int]         # number of new (query) tokens per sequence

    # Block table: [num_seqs, max_blocks_per_seq]
    # Maps logical block index → physical GPU block ID
    block_tables: torch.Tensor

    # Backend-specific fields also live here, e.g. causal-mask offsets for
    # prefill, or page indptr / last-page-length arrays for FlashInfer decode

    # Slot mapping: where to write new K/V values in the cache
    slot_mapping: torch.Tensor    # [num_tokens] → flat cache index
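As an illustration of how the length fields relate, here is a toy builder for a simplified metadata object. The class ToyBatchMetadata, the helper build_metadata, and the query_offsets field are hypothetical; the sketch treats any request with more than one new token as prefill and any request with exactly one as decode.

```python
from dataclasses import dataclass
from itertools import accumulate


@dataclass
class ToyBatchMetadata:
    num_prefill_tokens: int
    num_decode_tokens: int
    seq_lens: list[int]        # cached + new tokens per request
    query_lens: list[int]      # new tokens per request
    query_offsets: list[int]   # start of each request's tokens in the flat batch


def build_metadata(num_computed, num_new):
    """num_computed[i]: tokens already in request i's KV cache.
    num_new[i]: tokens scheduled for request i this step (1 for a decode).
    """
    query_lens = list(num_new)
    seq_lens = [c + n for c, n in zip(num_computed, num_new)]
    return ToyBatchMetadata(
        num_prefill_tokens=sum(n for n in query_lens if n > 1),
        num_decode_tokens=sum(n for n in query_lens if n == 1),
        seq_lens=seq_lens,
        query_lens=query_lens,
        query_offsets=[0, *accumulate(query_lens)],
    )


# One 4-token prompt plus two decoding requests with 17 and 40 cached tokens
md = build_metadata([0, 17, 40], [4, 1, 1])
```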

Slot Mapping

When new tokens arrive (prefill or decode), their K and V values must be written into the physical KV cache at the right position. The slot mapping translates each input token's logical position to a flat index into the physical cache tensor:

# Example: block_size = 4, the request owns 2 blocks, new token at logical position 6
# logical position 6 → block 1 (positions 4-7) → offset 2
# physical block ID = block_table[req][1] = 7
# slot = 7 * block_size + 2 = 7 * 4 + 2 = 30
slot_mapping[token_idx] = 30  # write K/V here in the flat cache
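The same arithmetic, applied to every new token, yields the whole slot mapping. A minimal NumPy sketch, assuming one request, a key cache laid out as [num_blocks, block_size, head_dim], and a hypothetical helper name (logical block 0 mapping to physical block 3 is also made up); it checks that writing through the flattened view lands in physical block 7, offset 2, as in the worked example.

```python
import numpy as np


def compute_slot_mapping(block_table, positions, block_size):
    """Flat cache slot for each token being written this step.

    block_table: physical block IDs for one request, indexed by logical block
    positions:   logical positions of the new tokens in that request
    """
    slots = []
    for pos in positions:
        logical_block, offset = divmod(pos, block_size)
        slots.append(block_table[logical_block] * block_size + offset)
    return slots


num_blocks, block_size, head_dim = 8, 4, 2
k_cache = np.zeros((num_blocks, block_size, head_dim))
flat = k_cache.reshape(-1, head_dim)   # view of the same memory, one row per slot

# Decode token at logical position 6; logical block 1 maps to physical block 7
slots = compute_slot_mapping([3, 7], [6], block_size)
flat[slots] = 1.0                      # writes k_cache[7, 2]
```

For a prefill, positions is simply the range of prompt positions, e.g. range(6) gives slots 12 to 15 in block 3 followed by 28 and 29 in block 7.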

Prefill vs. Decode Paths

The attention kernel behaves differently depending on the phase:

Prefill

Q, K and V all come from the prompt. A causal mask prevents each token from attending to later tokens. Compute-bound. Typically served by a FlashAttention-style kernel that tiles the computation to keep working data in on-chip SRAM. Sequence length can be thousands of tokens.

Decode

Q is one new token per sequence (query_len = 1). K and V come from the cache, plus the new token's own K and V. Bandwidth-bound, because every step must read all cached K/V for the sequence. Typically served by a paged gather kernel such as FlashInfer's.
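A back-of-envelope calculation shows why the two phases behave so differently. The sketch counts FLOPs per byte of K/V read for one head, ignoring Q and output traffic and the causal mask (which roughly halves prefill FLOPs); the function name and the 4096-token, head_dim 128 sizes are illustrative assumptions.

```python
def attention_intensity(num_query_tokens, kv_len, head_dim, dtype_bytes=2):
    """FLOPs per byte of K/V read for one head (ignoring Q/output traffic
    and the causal mask)."""
    flops = 2 * 2 * num_query_tokens * kv_len * head_dim   # Q·Kᵀ and P·V, multiply+add
    kv_bytes = 2 * kv_len * head_dim * dtype_bytes          # read K and V once
    return flops / kv_bytes


decode = attention_intensity(1, 4096, 128)      # one new token per sequence
prefill = attention_intensity(4096, 4096, 128)  # the whole prompt at once
print(decode, prefill)  # 1.0 4096.0
```

At about 1 FLOP per byte, decode attention is limited by memory bandwidth on current data-center GPUs, whose compute-to-bandwidth ratio is in the hundreds of FLOPs per byte; prefill reuses each loaded K/V tile across thousands of queries and becomes compute-bound.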

import torch

# In the model runner, one batch can contain BOTH:
# - Prefill tokens (a contiguous run at the front of input_ids)
# - Decode tokens (one token per decoding request, at the end)
#
# input_ids = [pref0, pref1, pref2, pref3, dec0, dec1]
#              <---- prefill (4 tok) ---->  <- decode (2 tok) ->
#
# Attention slices the flattened batch and handles each phase separately
# (flash_attn and flashinfer_paged_decode are placeholder names):
output = torch.empty_like(query)
if num_prefill_tokens > 0:
    p = slice(0, num_prefill_tokens)
    output[p] = flash_attn(query[p], key[p], value[p], ...)
if num_decode_tokens > 0:
    d = slice(num_prefill_tokens, num_prefill_tokens + num_decode_tokens)
    output[d] = flashinfer_paged_decode(query[d], kv_cache, block_tables, ...)