Overview
Codegen receives a single kernel AST (the SINK UOp from the scheduler) and produces a compiled PROGRAM UOp containing device-executable bytes. It does this through a sequence of graph-rewrite passes followed by linearization and rendering.
do_to_program(ast: UOp, renderer: Renderer) -> UOp
│
├─ full_rewrite_to_sink(ast, renderer) # optimize + lower to indexed loops
│
├─ do_linearize(renderer, prg, sink) # toposort UOps for execution
│
├─ do_render(renderer, prg, lin) # generate source string
│ └─ renderer.render(uops) -> str
│
└─ do_compile(renderer, prg, source) # compile to binary
└─ renderer.compiler.compile(src) -> bytes
codegen/__init__.py:176 — do_to_program
full_rewrite_to_sink — the optimization pipeline
codegen/__init__.py:27
This function applies a cascade of PatternMatcher passes, each transforming the UOp AST closer to device-runnable code.
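The sketch below makes the idea of a cascade concrete. It rewrites a tree bottom-up and keeps applying pattern rules until none of them fire. Node, the two rules, and run_passes are hypothetical stand-ins, not tinygrad's UOp or PatternMatcher API. The point it shows is that an early pass (constant folding) can expose work for a later one (dropping x + 0).

```python
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass(frozen=True)
class Node:
    # Hypothetical toy node: an op name, source nodes, and an optional argument.
    op: str
    src: tuple = ()
    arg: object = None

Rule = Callable[[Node], Optional[Node]]

def fold_const_mul(n: Node) -> Optional[Node]:
    # CONST * CONST -> CONST
    if n.op == "MUL" and all(s.op == "CONST" for s in n.src):
        return Node("CONST", arg=n.src[0].arg * n.src[1].arg)
    return None

def drop_add_zero(n: Node) -> Optional[Node]:
    # x + 0 -> x
    if n.op == "ADD" and n.src[1].op == "CONST" and n.src[1].arg == 0:
        return n.src[0]
    return None

def rewrite(n: Node, rules: list[Rule]) -> Node:
    # Rewrite sources first (bottom-up), then try each rule on this node.
    n = Node(n.op, tuple(rewrite(s, rules) for s in n.src), n.arg)
    for rule in rules:
        new = rule(n)
        if new is not None:
            return rewrite(new, rules)  # keep going until no rule fires
    return n

def run_passes(ast: Node, passes: list[list[Rule]]) -> Node:
    # Each pass is its own rule set, applied to the output of the previous pass.
    for rules in passes:
        ast = rewrite(ast, rules)
    return ast

x = Node("LOAD", arg="x")
ast = Node("ADD", (x, Node("MUL", (Node("CONST", arg=0), Node("CONST", arg=5)))))
print(run_passes(ast, [[fold_const_mul], [drop_add_zero]]))
```

Running only the second pass leaves the tree unchanged, because x + (0 * 5) does not match x + 0 until the first pass has folded the multiply.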
Phase 1 — Early movement ops
- pm_mops: Propagate/eliminate RESHAPE, EXPAND, PERMUTE, SHRINK through math ops
- pm_syntactic_sugar: Simplify redundant index/cast chains
- pm_store_ranges: Attach loop range info to STORE nodes
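Here is a toy version of the movement-op rules. It assumes a simple tuple encoding (op, arg, *sources) and assumes the rules push views down toward the loads. Two rewrites are shown: stacked RESHAPEs collapse into one, and a movement op sitting above an elementwise op is copied onto each input. This only illustrates the shape of such a rewrite. It is not pm_mops itself, and every name here is hypothetical.

```python
# Toy nodes are tuples: (op, arg, *sources). All names are hypothetical.
MOVEMENT = {"RESHAPE", "EXPAND", "PERMUTE", "SHRINK"}
ELEMENTWISE = {"ADD", "MUL"}

def push_movement(node: tuple) -> tuple:
    op, arg, *srcs = node
    srcs = [push_movement(s) for s in srcs]
    # Elimination: a RESHAPE of a RESHAPE only needs the outer target shape.
    if op == "RESHAPE" and srcs and srcs[0][0] == "RESHAPE":
        return push_movement(("RESHAPE", arg, srcs[0][2]))
    # Propagation: move the view below an elementwise op onto each of its inputs.
    if op in MOVEMENT and srcs and srcs[0][0] in ELEMENTWISE:
        inner_op, inner_arg, *inner_srcs = srcs[0]
        return (inner_op, inner_arg, *(push_movement((op, arg, s)) for s in inner_srcs))
    return (op, arg, *srcs)

expr = ("RESHAPE", (32, 32), ("ADD", None, ("LOAD", "a"), ("LOAD", "b")))
print(push_movement(expr))
```

Pushing views to the leaves means each LOAD ends up with its own view, which later becomes plain index arithmetic on that buffer.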
Phase 2 — Optimization (when optimize=True)
- Load collapse: eliminate redundant memory reads within a kernel
- Range splitting: divide loops for parallelism (local/global split)
- Symbolic simplification: constant-fold index arithmetic
- BEAM search or hand-coded heuristics to pick best tiling/unrolling
- Post-range opts: loop reordering, cache tiling
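Range splitting and the index simplification that follows it can be checked with plain arithmetic. The sketch below assumes a local size of 16, which is an arbitrary choice here. In practice the value comes from BEAM or the heuristics. It verifies that flat = gid * 16 + lid visits every index of range(1024) exactly once, and it also checks the // and % identities that a symbolic simplifier can use to fold index expressions back to gid and lid.

```python
def split_range(n: int, local_size: int) -> list[tuple[int, int, int]]:
    """Return (global_id, local_id, flat_index) for every iteration of range(n)."""
    assert n % local_size == 0, "this sketch only handles sizes that divide evenly"
    return [(g, l, g * local_size + l)
            for g in range(n // local_size)   # global dimension
            for l in range(local_size)]       # local dimension

iters = split_range(1024, 16)
# Symbolic simplification can rewrite (g*16 + l) // 16 -> g and (g*16 + l) % 16 -> l.
print(all(i // 16 == g and i % 16 == l for g, l, i in iters))
```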
Phase 3 — Expander
- pm_pre_expander: Prepare ops for vectorization/unrolling
- pm_group_for_reduce: Identify reduction groups for warp/block operations
- expander: Expand complex ops (tensor cores, vector) into constituent UOps
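The sketch below gives a rough sense of what unrolling an axis into vector lanes means. It assumes an unroll width that divides the loop length. The outer loop shrinks by the width factor, and the body receives a whole tuple of lane indices at once instead of a single scalar index. run_unrolled is a hypothetical helper. The real expander rewrites UOps rather than running Python.

```python
from typing import Callable, Sequence

def run_unrolled(n: int, width: int,
                 body: Callable[[tuple[int, ...]], Sequence[float]]) -> list[float]:
    assert n % width == 0, "sketch assumes the unroll factor divides the loop"
    out: list[float] = []
    for outer in range(n // width):                              # the loop that remains
        lanes = tuple(outer * width + k for k in range(width))   # one index per lane
        out.extend(body(lanes))                                  # body handles all lanes at once
    return out

a = [float(i) for i in range(8)]

def double(idx: tuple[int, ...]) -> tuple[float, ...]:
    # A "vector" body: load every lane, multiply every lane.
    return tuple(a[i] * 2.0 for i in idx)

print(run_unrolled(8, 4, double))
```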
Phase 4 — Local buffer & range concretization
- pm_add_buffers_local: Insert DEFINE_LOCAL for shared memory allocations
- rangeify_codegen: Finalize RANGE/END pairs with concrete integer bounds
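Reductions are a common reason for a kernel to need DEFINE_LOCAL. The code below is a sequential simulation of a workgroup reduction through a shared buffer, assuming a power-of-two local size. The local list stands in for shared memory, and each lid iteration stands in for one thread. Comments mark where a real kernel would need barriers. This illustrates the general technique only. It is not the code tinygrad emits.

```python
def group_reduce(data: list[int], local_size: int = 8) -> int:
    assert local_size > 0 and local_size & (local_size - 1) == 0, "power of two"
    local = [0] * local_size            # stands in for a DEFINE_LOCAL buffer
    for lid in range(local_size):       # phase 1: each thread sums a strided slice
        local[lid] = sum(data[lid::local_size])
    # (a real kernel needs a barrier here before threads read each other's slots)
    stride = local_size // 2
    while stride > 0:                   # phase 2: tree reduction in local memory
        for lid in range(stride):
            local[lid] += local[lid + stride]
        stride //= 2
        # (and another barrier after each step)
    return local[0]

print(group_reduce(list(range(100))))
```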
Phase 5 — Devectorize & lower
- pm_reduce: Convert REDUCE ops to explicit loop + accumulator patterns
- pm_add_loads: Insert explicit LOAD for each variable use
- Devectorize passes: break vector ops into scalar instructions
- Index dtype lowering: replace symbolic shapes with int32/int64
- Op decomposition: rewrite ops not in renderer.supported_ops to equivalents
- Transcendental lowering: replace sin/cos/exp2 with polynomial approximations if needed
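Transcendental lowering can be illustrated with exp2. The sketch below computes it using only multiplies, adds, a floor, and an exact power-of-two scale. It splits x into an integer part n and a fraction f in [0, 1), approximates 2**f with a polynomial, and then scales by 2**n. The polynomial is a plain Taylor series for e**(f ln 2), chosen here for clarity. A production lowering would typically use tuned coefficients instead.

```python
import math

def exp2_approx(x: float, degree: int = 9) -> float:
    # Range reduction: x = n + f with integer n and f in [0, 1).
    n = math.floor(x)
    f = x - n
    # 2**f == e**(f * ln 2); approximate e**t with its Taylor polynomial.
    t = f * math.log(2.0)
    coeffs = [1.0 / math.factorial(k) for k in range(degree + 1)]
    p = 0.0
    for c in reversed(coeffs):      # Horner's rule: only multiplies and adds
        p = p * t + c
    # Reassemble: 2**x == 2**n * 2**f; ldexp scales by 2**n exactly.
    return math.ldexp(p, n)

for x in (-3.5, 0.0, 0.5, 10.25):
    print(x, exp2_approx(x), 2.0 ** x)
```

Because f stays in [0, 1), the polynomial only has to be accurate on that short interval. This is why range reduction comes before the polynomial.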
do_linearize — instruction ordering
codegen/__init__.py:132
Converts the DAG SINK back into a flat ordered list of UOps (list[UOp]) that a renderer can iterate over sequentially.
1. Topological sort with run-count priority
UOps inside loops execute many times, so scheduling accounts for multiplicity: ops inside nested loops must come after the RANGE nodes that open those loops.
2. PLACE ordering
Ensures control-flow structure (RANGE/END pairs, IF/ENDIF) is correct and that variable definitions precede uses.
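Both steps can be approximated with Kahn's algorithm and a priority queue. This sketch reflects an assumption about the idea, not tinygrad's linearizer. Each node carries a run count, ready nodes are popped lowest run count first, and ties fall back to the original order. Giving the RANGE a score equal to its own trip count (also an illustrative choice) delays opening the loop, so the loop-invariant SCALE lands before it.

```python
import heapq

def linearize(nodes: dict[str, tuple[list[str], int]]) -> list[str]:
    """nodes maps name -> (dependencies, run_count); dict order breaks ties."""
    order = {name: i for i, name in enumerate(nodes)}
    indegree = {name: len(deps) for name, (deps, _) in nodes.items()}
    users: dict[str, list[str]] = {name: [] for name in nodes}
    for name, (deps, _) in nodes.items():
        for d in deps:
            users[d].append(name)
    # Ready nodes, ordered by (run_count, original position).
    heap = [(nodes[n][1], order[n], n) for n in nodes if indegree[n] == 0]
    heapq.heapify(heap)
    out: list[str] = []
    while heap:
        _, _, name = heapq.heappop(heap)
        out.append(name)
        for u in users[name]:
            indegree[u] -= 1
            if indegree[u] == 0:
                heapq.heappush(heap, (nodes[u][1], order[u], u))
    assert len(out) == len(nodes), "graph has a cycle"
    return out

graph = {
    "RANGE":  ([], 1024),                  # scored by trip count: opened as late as possible
    "LOAD_a": (["RANGE"], 1024),
    "CONST":  ([], 1),
    "SCALE":  (["CONST"], 1),              # loop-invariant: depends only on CONST
    "MUL":    (["LOAD_a", "SCALE"], 1024),
    "STORE":  (["MUL", "RANGE"], 1024),
    "END":    (["STORE", "RANGE"], 1),
}
print(linearize(graph))
```

Dependencies still take precedence over priority. LOAD_a cannot be scheduled before RANGE regardless of its score, and END waits for STORE.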
do_render & do_compile
codegen/__init__.py:155–170
do_render
codegen/__init__.py:155
Calls renderer.render(uops_list), which walks the linearized list and emits the device-specific source string (C, PTX, WGSL, Metal, LLVM IR, etc.).
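A renderer in this style is essentially a loop over the flat list, with one formatting rule per op. The sketch below assumes a hypothetical tuple encoding for linearized UOps and emits C-like source. RANGE opens a for loop and END closes it. That only produces valid code because linearization has already ordered the ops correctly. The function and kernel names are illustrative.

```python
def render(uops: list[tuple], name: str = "kernel_sketch") -> str:
    lines: list[str] = []
    depth = 1                                   # indentation level inside the function body
    for op, *args in uops:
        pad = "  " * depth
        if op == "RANGE":
            var, n = args
            lines.append(f"{pad}for (int {var} = 0; {var} < {n}; {var}++) {{")
            depth += 1
        elif op == "END":
            depth -= 1
            lines.append("  " * depth + "}")
        elif op == "LOAD":
            dst, buf, idx = args
            lines.append(f"{pad}float {dst} = {buf}[{idx}];")
        elif op == "ADD":
            dst, x, y = args
            lines.append(f"{pad}float {dst} = {x} + {y};")
        elif op == "STORE":
            buf, idx, val = args
            lines.append(f"{pad}{buf}[{idx}] = {val};")
        else:
            raise NotImplementedError(f"no rendering rule for {op}")
    header = f"void {name}(float* out, const float* a, const float* b) {{"
    return "\n".join([header, *lines, "}"])

uops = [("RANGE", "i", 1024), ("LOAD", "v0", "a", "i"), ("LOAD", "v1", "b", "i"),
        ("ADD", "v2", "v0", "v1"), ("STORE", "out", "i", "v2"), ("END",)]
print(render(uops))
```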
do_compile
codegen/__init__.py:160
Calls renderer.compiler.compile_cached(source), which invokes the device compiler (nvrtc for CUDA, the Metal compiler for Metal, LLVM for CPU, and so on) and returns the compiled binary bytes as a BINARY UOp.
Compilation results are cached by source hash in Compiler.compile_cached(), so identical kernels (same ops, same shapes) are only compiled once per device per process.
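The caching behavior can be sketched with a dictionary keyed on a hash of the source. CachingCompiler is hypothetical. Its in-memory dict and its use of SHA-256 are assumptions made for this sketch, and the real Compiler.compile_cached may key and store its cache differently.

```python
import hashlib

class CachingCompiler:
    def __init__(self) -> None:
        self._cache: dict[str, bytes] = {}
        self.compile_calls = 0              # counts how often the expensive path runs

    def compile(self, src: str) -> bytes:
        # Stand-in for an expensive device compiler (nvrtc, Metal, LLVM, ...).
        self.compile_calls += 1
        return b"BIN:" + src.encode()

    def compile_cached(self, src: str) -> bytes:
        key = hashlib.sha256(src.encode()).hexdigest()
        if key not in self._cache:
            self._cache[key] = self.compile(src)
        return self._cache[key]

cc = CachingCompiler()
cc.compile_cached("kernel void k() {}")
cc.compile_cached("kernel void k() {}")
print(cc.compile_calls)
```

Keying on the rendered source, rather than on the AST, means that any two kernels that render to identical text share one binary.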
Transformation example: a simple elementwise kernel
# Input AST (from scheduler):
SINK(
STORE(BUFFER(out, shape=(1024,)),
ADD(LOAD(BUFFER(a, shape=(1024,))),
LOAD(BUFFER(b, shape=(1024,)))))
)
# After full_rewrite_to_sink:
SINK(
RANGE(i, 0, 1024) ─────────────────────────────┐
STORE(INDEX(BUFFER(out), i), │
ADD(LOAD(INDEX(BUFFER(a), i)), │
LOAD(INDEX(BUFFER(b), i)))) │
END(RANGE) ◄──────────────────────────────────┘
)
# After linearize: flat list
[RANGE(i,0,1024), LOAD(a[i]), LOAD(b[i]), ADD, STORE(out[i]), END]
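# After render (OpenCL C example; the RANGE over i has been assigned to a global dimension, so it renders as get_global_id(0) instead of a loop):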
"""
kernel void r_1024(global float* out, global float* a, global float* b) {
int i = get_global_id(0);
out[i] = a[i] + b[i];
}
"""
# After compile: bytes of .ptx / .metallib / .so etc.
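The flat list from the linearize step can also be executed directly, which shows why RANGE/END ordering matters. Below is a small interpreter over a hypothetical tuple encoding of that list. It assumes n = 8 instead of 1024 to keep the check quick. END jumps back to the first instruction after its RANGE until the trip count is reached.

```python
def run(uops: list[tuple], bufs: dict[str, list[float]]) -> None:
    env: dict[str, int | float] = {}          # loop variables and temporaries
    loops: list[tuple[int, str, int]] = []    # open loops: (first body pc, var, trip count)
    pc = 0
    while pc < len(uops):
        op, *args = uops[pc]
        if op == "RANGE":
            var, n = args
            assert n > 0, "sketch assumes non-empty loops"
            env[var] = 0
            loops.append((pc + 1, var, n))
        elif op == "END":
            start, var, n = loops[-1]
            env[var] += 1
            if env[var] < n:
                pc = start                    # jump back to the top of the loop body
                continue
            loops.pop()
        elif op == "LOAD":
            dst, buf, idx = args
            env[dst] = bufs[buf][env[idx]]
        elif op == "ADD":
            dst, x, y = args
            env[dst] = env[x] + env[y]
        elif op == "STORE":
            buf, idx, val = args
            bufs[buf][env[idx]] = env[val]
        else:
            raise NotImplementedError(op)
        pc += 1

n = 8
bufs = {"a": [float(i) for i in range(n)], "b": [10.0 * i for i in range(n)], "out": [0.0] * n}
uops = [("RANGE", "i", n), ("LOAD", "v0", "a", "i"), ("LOAD", "v1", "b", "i"),
        ("ADD", "v2", "v0", "v1"), ("STORE", "out", "i", "v2"), ("END",)]
run(uops, bufs)
print(bufs["out"])
```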