LLM Inference Optimization Engine
CUDA · Python · C++ · FlashAttention · Quantization
Overview
A from-scratch inference engine for transformer-based language models, built to explore the performance engineering techniques that underpin production systems like vLLM and TensorRT-LLM. The goal was to understand — and measure — each optimization in isolation before combining them.
Technical Implementation
KV Cache & Paged Attention
Autoregressive decoding is memory-bound: the key/value tensors from prior tokens must be
re-read on every step. A contiguous per-sequence KV cache avoids recomputation but wastes GPU
memory, since each sequence must reserve space up to its maximum length and variable-length
requests fragment the pool. Paged attention (as in vLLM) allocates KV blocks in fixed-size
pages, allowing non-contiguous physical storage; internal fragmentation is bounded by one
partially filled page per sequence.
```python
# Simplified paged-attention KV lookup
def gather_kv(block_table, cache, seq_len, block_size, head_dim):
    n_blocks = (seq_len + block_size - 1) // block_size  # ceil: last page may be partial
    pages = cache[block_table[:n_blocks]]                # gather physical pages
    return pages.reshape(-1, head_dim)[:seq_len]
```
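As a concrete sketch of the lookup, here is a numpy stand-in (page count, block size, and head dimension are arbitrary toy values, not the engine's real configuration):

```python
import numpy as np

BLOCK_SIZE, HEAD_DIM, NUM_PAGES = 4, 8, 16

# Physical KV pool: NUM_PAGES pages of BLOCK_SIZE token slots each
cache = np.arange(NUM_PAGES * BLOCK_SIZE * HEAD_DIM, dtype=np.float32)
cache = cache.reshape(NUM_PAGES, BLOCK_SIZE, HEAD_DIM)

# A 10-token sequence scattered across non-contiguous physical pages 5, 2, 11
block_table = np.array([5, 2, 11])
seq_len = 10

n_blocks = (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE   # = 3 (last page partial)
kv = cache[block_table[:n_blocks]].reshape(-1, HEAD_DIM)[:seq_len]

assert kv.shape == (10, 8)
# Logical row 0 lives in physical page 5: storage is non-contiguous
assert np.array_equal(kv[0], cache[5, 0])
```

Only the small block table grows per sequence; pages are recycled from a global pool when sequences finish.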
FlashAttention
Standard attention materializes the full N×N score matrix in HBM, which becomes the
bottleneck at long context. FlashAttention tiles the computation into SRAM-resident blocks,
computing a numerically stable softmax with online updates — reducing HBM reads/writes from
O(N²) to O(N).
```cpp
// Tiled SRAM attention kernel (pseudocode)
for (int j = 0; j < Tc; ++j) {
    load_kv_tile(K, V, j);                          // K,V tile → SRAM
    S_ij  = Q_i @ K_j.T;                            // score tile
    m_new = max(m_i, rowmax(S_ij));                 // online max
    P_ij  = exp(S_ij - m_new);
    l_i   = exp(m_i - m_new) * l_i + rowsum(P_ij);  // running softmax denominator
    O_i   = rescale(O_i, m_i, m_new) + P_ij @ V_j;  // un-normalized output
    m_i   = m_new;
}
O_i = O_i / l_i;                                    // normalize once after the loop
```
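The online-softmax recurrence can be checked against a reference softmax in plain numpy. This is a single-query sketch of the math, not the CUDA kernel itself:

```python
import numpy as np

def online_softmax_attention(q, K, V, tile=4):
    """Single-query attention computed tile-by-tile with an online softmax."""
    d = q.shape[-1]
    m = -np.inf           # running row max
    l = 0.0               # running softmax denominator
    o = np.zeros(d)       # un-normalized output accumulator
    for j in range(0, K.shape[0], tile):
        Kj, Vj = K[j:j + tile], V[j:j + tile]
        s = Kj @ q                     # score tile
        m_new = max(m, s.max())        # online max
        p = np.exp(s - m_new)
        scale = np.exp(m - m_new)      # rescale previous accumulators
        l = scale * l + p.sum()
        o = scale * o + p @ Vj
        m = m_new
    return o / l                       # normalize once at the end

rng = np.random.default_rng(0)
q = rng.normal(size=8)
K, V = rng.normal(size=(16, 8)), rng.normal(size=(16, 8))
scores = K @ q
ref = np.exp(scores - scores.max())
ref /= ref.sum()
assert np.allclose(online_softmax_attention(q, K, V), ref @ V)
```

Because each tile only rescales the running accumulators, the full N×N score matrix never needs to exist at once, which is exactly what keeps it out of HBM.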
Speculative Decoding
A small draft model proposes K tokens in parallel; the target model verifies them in a
single forward pass. Accepted tokens are kept; the first rejection rolls back. When the
draft acceptance rate is high, this yields near-K× throughput improvement with
identical output distribution.
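A toy sketch of the propose/verify/roll-back loop under greedy decoding. The `target_next` and `draft_next` callables are hypothetical stand-ins for real forward passes; full speculative sampling uses a probability-ratio accept/reject test to preserve the target distribution, which the greedy case below reduces to an exact-match check:

```python
def speculative_step(target_next, draft_next, prefix, k=3):
    """One speculative decoding step under greedy sampling.

    target_next / draft_next map a token prefix to that model's next token.
    Returns the tokens committed by this step.
    """
    # Draft model proposes k tokens autoregressively (cheap, sequential)
    proposal, ctx = [], list(prefix)
    for _ in range(k):
        t = draft_next(ctx)
        proposal.append(t)
        ctx.append(t)

    # Target model scores all k positions in one batched pass (expensive,
    # but done once). Greedy rule: accept while the draft token equals the
    # target argmax at that position.
    accepted, ctx = [], list(prefix)
    for t in proposal:
        expected = target_next(ctx)
        if t != expected:
            accepted.append(expected)   # first rejection: take target token, stop
            return accepted
        accepted.append(t)              # match: keep going
        ctx.append(t)
    accepted.append(target_next(ctx))   # all k accepted: free bonus token
    return accepted

# Toy models: target's next token is the position index; draft agrees until position 5
target = lambda ctx: len(ctx)
draft = lambda ctx: len(ctx) if len(ctx) < 5 else 0
assert speculative_step(target, draft, [0, 1, 2], k=3) == [3, 4, 5]   # reject at pos 5
assert speculative_step(target, target, [0, 1, 2], k=3) == [3, 4, 5, 6]  # all accepted
```

Note the asymmetry that makes this pay off: every committed token still costs a draft forward pass, but the target model runs once per step regardless of how many tokens are accepted.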
Low-Precision Matrix Multiply
INT8 weight quantization (with per-channel scaling) halves weight memory traffic relative
to FP16, the dominant cost during decode. FP8 and W4A16 schemes were also benchmarked
using custom CUDA extensions and Triton kernels.
```python
# W8A8 linear layer (simplified)
def linear_int8(x, W_q, scale_x, scale_w):
    x_q = quantize_per_tensor(x, scale_x)       # FP16 → INT8
    out_q = torch._int_mm(x_q, W_q.T)           # INT32 accumulation
    return out_q.float() * (scale_x * scale_w)  # dequantize
```
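The per-channel scaling can be sketched in numpy as symmetric round-to-nearest quantization, one scale per output channel; the helper name and shapes here are illustrative, not the engine's actual API:

```python
import numpy as np

def quantize_weights_per_channel(W):
    """Symmetric INT8 quantization with one scale per output channel (row)."""
    # Choose each row's scale so its max |w| maps exactly to 127
    scale = np.abs(W).max(axis=1, keepdims=True) / 127.0
    W_q = np.clip(np.round(W / scale), -127, 127).astype(np.int8)
    return W_q, scale

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 64)).astype(np.float32)
W_q, scale = quantize_weights_per_channel(W)

# Dequantized weights stay within half a quantization step of the originals
assert np.max(np.abs(W_q.astype(np.float32) * scale - W)) <= scale.max()
```

Per-channel scales matter because outlier channels would otherwise force a single coarse scale on the whole matrix; the cost is that `scale_w` becomes a vector folded into the dequantize step.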
Results
- Paged attention reduced peak KV memory by ~35% vs. contiguous cache at batch size 32.
- FlashAttention kernel achieved ~3.1× lower attention latency vs. eager PyTorch at seq_len=2048.
- INT8 weight quantization delivered ~1.8× decode throughput vs. BF16 baseline on A100.
- Speculative decoding (3-token draft) achieved ~2.4× tokens/sec on greedy tasks.