LLM Inference Optimization Engine

CUDA · Python · C++ · FlashAttention · Quantization

Overview

A from-scratch inference engine for transformer-based language models, built to explore the performance engineering techniques that underpin production systems like vLLM and TensorRT-LLM. The goal was to understand — and measure — each optimization in isolation before combining them.

Technical Implementation

KV Cache & Paged Attention
Autoregressive decoding is memory-bound: the key/value tensors from every prior token must be re-read at each step. Caching them avoids recomputation, but a contiguous cache pre-allocated to the maximum sequence length wastes GPU memory through internal fragmentation. Paged attention (as in vLLM) allocates the KV cache in fixed-size pages mapped through a per-sequence block table, allowing non-contiguous physical storage and limiting waste to at most one partially filled page per sequence.

# Simplified paged-attention KV lookup
def gather_kv(block_table, cache, seq_len, block_size):
    n_blocks = (seq_len + block_size - 1) // block_size  # ceil division
    pages = cache[block_table[:n_blocks]]    # gather physical pages
    return pages.view(-1, HEAD_DIM)[:seq_len]  # trim padding in last page
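The same lookup can be exercised end-to-end with NumPy in place of torch (a toy sketch; `num_pages`, `block_size`, and `head_dim` are made-up parameters for illustration):

```python
import numpy as np

block_size, head_dim, num_pages = 4, 8, 16

# Physical KV cache: a pool of fixed-size pages.
cache = np.random.rand(num_pages, block_size, head_dim)

# A 10-token sequence occupies ceil(10/4) = 3 logical blocks, which
# the block table maps to arbitrary (non-contiguous) physical pages.
seq_len = 10
block_table = np.array([7, 2, 11])

def gather_kv(block_table, cache, seq_len, block_size):
    n_blocks = (seq_len + block_size - 1) // block_size
    pages = cache[block_table[:n_blocks]]          # gather physical pages
    return pages.reshape(-1, cache.shape[-1])[:seq_len]

kv = gather_kv(block_table, cache, seq_len, block_size)
assert kv.shape == (seq_len, head_dim)
# Logical row 5 lives in logical block 1 (physical page 2), offset 1.
assert np.array_equal(kv[5], cache[2, 1])
```

The block table is the only per-sequence state: growing a sequence just appends a page index, and freeing it returns whole pages to the pool.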

FlashAttention
Standard attention materializes the full N×N score matrix in HBM, which becomes the bottleneck at long context. FlashAttention tiles the computation into SRAM-resident blocks, computing a numerically stable softmax with online updates — reducing HBM reads/writes from O(N²) to O(N).

// Tiled SRAM attention kernel (pseudocode)
for (int j = 0; j < Tc; ++j) {
    load_kv_tile(K, V, j);            // K,V tile → SRAM
    S_ij  = Q_i @ K_j.T;              // score tile
    m_new = max(m_i, rowmax(S_ij));   // online max
    P_ij  = exp(S_ij - m_new);
    O_i   = rescale(O_i, m_i, m_new) + P_ij @ V_j;
    l_i   = rescale(l_i, m_i, m_new) + rowsum(P_ij);  // running denominator
    m_i   = m_new;
}
O_i /= l_i;                           // final softmax normalization
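The online-softmax recurrence above is easy to check numerically in NumPy against a naive reference; shapes and the tile size here are arbitrary:

```python
import numpy as np

np.random.seed(0)
N, d, Bc = 64, 16, 8                         # seq len, head dim, KV tile size
Q, K, V = (np.random.randn(N, d) for _ in range(3))

# Reference: materialize the full N×N score matrix.
S = Q @ K.T
P = np.exp(S - S.max(1, keepdims=True))
ref = (P / P.sum(1, keepdims=True)) @ V

# Tiled version: stream K/V tiles, keep running max m, sum l, accumulator O.
m = np.full(N, -np.inf)
l = np.zeros(N)
O = np.zeros((N, d))
for j in range(0, N, Bc):
    S_j   = Q @ K[j:j+Bc].T                  # score tile
    m_new = np.maximum(m, S_j.max(1))        # online max
    P_j   = np.exp(S_j - m_new[:, None])
    alpha = np.exp(m - m_new)                # rescale factor for old stats
    O     = O * alpha[:, None] + P_j @ V[j:j+Bc]
    l     = l * alpha + P_j.sum(1)
    m     = m_new
O /= l[:, None]

assert np.allclose(O, ref)
```

Only one `Bc`-wide tile of scores exists at any time, which is exactly what lets the real kernel keep everything in SRAM.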

Speculative Decoding
A small draft model proposes K tokens in parallel; the target model verifies them in a single forward pass. Tokens up to the first mismatch are kept, and decoding rolls back to that point. When the draft acceptance rate is high, this yields near-K× throughput improvement with an output distribution identical to the target model's.
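The accept/reject loop for the greedy case can be sketched in pure Python (`draft_next` and `target_next` below are made-up stand-ins for the two models; a real engine batches the verification into one forward pass):

```python
def speculative_step(prefix, draft_next, target_next, k=3):
    """One round of draft-then-verify greedy speculative decoding."""
    # 1. Draft model proposes k tokens autoregressively (cheap).
    ctx = list(prefix)
    proposal = []
    for _ in range(k):
        t = draft_next(ctx)
        proposal.append(t)
        ctx.append(t)

    # 2. Target model checks every position (batched in practice;
    #    simulated sequentially here).
    out = list(prefix)
    for t in proposal:
        v = target_next(out)
        if v != t:                   # first mismatch: keep the target's
            out.append(v)            # token and discard the rest
            return out
        out.append(t)
    out.append(target_next(out))     # all k accepted: free bonus token
    return out

# Deterministic toy "models": the target is ground truth; the draft
# agrees except when the context sum is divisible by 7.
target_next = lambda ctx: (sum(ctx) * 3 + 1) % 11
draft_next  = lambda ctx: target_next(ctx) if sum(ctx) % 7 else 0

# The output must be identical to plain greedy decoding with the target.
seq = [1]
while len(seq) < 20:
    seq = speculative_step(seq, draft_next, target_next)
greedy = [1]
while len(greedy) < len(seq):
    greedy.append(target_next(greedy))
assert seq == greedy
```

Because every kept token matches what the target would have produced greedily, the speedup comes purely from amortizing target forward passes, never from changing the output.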

Low-Precision Matrix Multiply
INT8 weight quantization (with per-channel scaling) halves weight-memory traffic relative to FP16 — the dominant bottleneck during decode. FP8 and W4A16 schemes were also benchmarked using custom CUDA extensions and Triton kernels.

# W8A8 linear layer (simplified)
def linear_int8(x, W_q, scale_x, scale_w):
    x_q   = quantize_per_tensor(x, scale_x)     # → INT8
    out_q = torch._int_mm(x_q, W_q.T)           # INT32 accumulation
    return out_q.float() * (scale_x * scale_w)  # dequantize
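The per-channel weight scales mentioned above can be sketched in NumPy (symmetric quantization; the error bound checked at the end follows directly from rounding to the nearest step):

```python
import numpy as np

np.random.seed(0)
W = np.random.randn(256, 512).astype(np.float32)   # [out_ch, in_ch] weights

# Symmetric per-output-channel scaling: one scale per row of W,
# chosen so the row's largest magnitude maps to 127.
scale_w = np.abs(W).max(axis=1, keepdims=True) / 127.0
W_q = np.clip(np.round(W / scale_w), -127, 127).astype(np.int8)

# Dequantize: per-element error is at most half a quantization step.
W_dq = W_q.astype(np.float32) * scale_w
assert np.abs(W - W_dq).max() <= scale_w.max() / 2 + 1e-6
```

Per-channel (rather than per-tensor) scales matter because weight rows can differ in magnitude by orders of magnitude; a single shared scale would crush the small rows to zero.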

Results

  • Paged attention reduced peak KV memory by ~35% vs. contiguous cache at batch size 32.
  • FlashAttention kernel achieved ~3.1× lower attention latency vs. eager PyTorch at seq_len=2048.
  • INT8 weight quantization delivered ~1.8× decode throughput vs. BF16 baseline on A100.
  • Speculative decoding (3-token draft) achieved ~2.4× tokens/sec on greedy tasks.

Links