
spjacobian Scalability Analysis

Problem Statement

spjacobian applied to large functions (e.g., eq_constraints_fn2 of the bicycle OCP with N=50) fails because tinygrad's scheduler doesn't scale with graph size. The scheduler's graph_rewrite hits the REWRITE_STACK_LIMIT (250K) at ~3500 UOp nodes.

Quantified Bottleneck

For eq_constraints_fn2 (bicycle OCP, N=50, 305 inputs → 204 outputs, 8 colors, 1204 nnz):

Component                            UOp count
Base function graph                        169
One JVP (generic tangent)                  367
One JVP (CONST-based basis column)       2,185
All 8 JVPs stacked                      14,101
+ uncompression (1204 gathers)          21,410

The CONST-based basis explodes the graph because Tensor.stack(*[Tensor.full((), v) for v in 305_values]) creates ~6 UOps per element (CONST + RESHAPE + PAD_bounds + PAD + ADD chain) = ~1,830 UOps per basis column.

The scheduler breaks at N=10 (3,616 nodes). N=7 (3,071 nodes) succeeds in ~5s.

Scaling comparison

Approach                            Graph size                       Codegen time (N=50)   Code quality
spjacobian(unroll=True)             O(n_colors × n_inputs + n_nnz)   FAILS                 Best (DCE)
spjacobian(unroll=False) (vmap)     O(function_graph)                0.12s                 Good (no DCE)
Per-block (eq_constraints_jac_fn)   O(block_graph + assembly)       0.32s                 Best (per-block DCE)

Root Causes

1. tinygrad doesn't fold constants at UOp creation time

from tinygrad import Tensor

x = Tensor.empty(5)
zero = x.uop.const_like(0)
result = zero * x.uop   # creates a MUL UOp, NOT folded to CONST(0)
result2 = zero + x.uop  # creates an ADD UOp, NOT folded to x

These are only folded during graph_rewrite with the symbolic pattern matcher — which happens in the scheduler, AFTER the graph is fully built.
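The missing behavior can be illustrated with a toy node type (pure Python, invented here for illustration; this is not tinygrad's API) that folds the two identities at construction time instead of deferring them to a rewrite pass:

```python
# Toy expression node that folds 0 * x -> 0 and 0 + x -> x at creation
# time. Illustration only: Expr and its methods are invented for this sketch.

class Expr:
    def __init__(self, op, src=(), value=None):
        self.op, self.src, self.value = op, src, value

    def is_const(self, v):
        return self.op == "const" and self.value == v

    def __mul__(self, other):
        if self.is_const(0) or other.is_const(0):
            return Expr("const", value=0)      # folded: no MUL node is built
        return Expr("mul", (self, other))

    def __add__(self, other):
        if self.is_const(0): return other      # folded: 0 + x -> x
        if other.is_const(0): return self
        return Expr("add", (self, other))

x = Expr("var")
zero = Expr("const", value=0)
assert (zero * x).is_const(0)   # stays a constant; no MUL analogue created
assert (zero + x) is x          # collapses to x; no ADD analogue created
```

With folding at creation time, zero tangents never enter the graph in the first place, so the scheduler's rewrite pass has far fewer nodes to visit.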

2. Tensor-level operations prevent scalar zero-skipping

The function operates on batch tensors (e.g., dynamics on (4, 50) inputs). Even when the tangent vector has many zeros at the scalar level, the tangent of a (4, 50) tensor is a single UOp — NOT 200 individual scalar UOps. We can't detect per-element zeros at the UOp level.

This is the fundamental mismatch with CasADi: CasADi works at the scalar expression level (MX/SX), where each element is a separate node, enabling per-element zero-skipping. tinygrad works at the tensor level, where batch operations are single nodes.

3. Tensor.stack of CONSTs is extremely expensive

Each element in the basis creates a CONST→RESHAPE→PAD→ADD chain. For 305 elements × 8 colors = 2,440 scalars, this creates ~14,640 UOps just for the basis representation (before any JVP computation).


Strategies

Strategy 1: JVP-level constant folding

Even though batch-level zero-skipping doesn't work, tensor-level zero-skipping does help at block boundaries. For example, for a basis column that doesn't affect gamma, the tangent of gamma is CONST(0), and all downstream operations involving only gamma's tangent can be eliminated.

Implementation: Modify JVP rules to check for CONST(0) tangents:

def _is_zero(u: UOp) -> bool:
    return u.op is Ops.CONST and u.arg == 0

def _mul_jvp(ctx, ret, x, y):
    dx, dy = _tan(ctx, x), _tan(ctx, y)
    if _is_zero(dx) and _is_zero(dy): return _zero_like(ret)
    if _is_zero(dx): return x * dy
    if _is_zero(dy): return y * dx
    return y * dx + x * dy

Impact assessment: for the bicycle OCP with 8 colors, the tangent of gamma is CONST(0) for ~7 of the 8 colors. This eliminates gamma-related derivative computations, reducing the graph by roughly 15% per JVP call. Helpful, but not sufficient on its own.
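The payoff can be sketched with a toy node counter (pure Python; Node and mul_jvp are invented names, not tinygrad's JVP machinery): with a known-zero tangent, the product rule builds one node instead of three.

```python
class Node:
    """Toy expression node; counts every non-const node that gets built."""
    built = 0
    def __init__(self, op, *src):
        self.op, self.src = op, src
        if op != "const":
            Node.built += 1

ZERO = Node("const")  # stands in for a CONST(0) tangent

def mul_jvp(x, dx, y, dy, skip_zeros):
    # product rule: d(x*y) = y*dx + x*dy, optionally dropping zero terms
    if skip_zeros:
        if dx is ZERO and dy is ZERO: return ZERO
        if dx is ZERO: return Node("mul", x, dy)
        if dy is ZERO: return Node("mul", y, dx)
    return Node("add", Node("mul", y, dx), Node("mul", x, dy))

x, y, dx = Node("var"), Node("var"), Node("var")

Node.built = 0
mul_jvp(x, dx, y, ZERO, skip_zeros=False)
unskipped = Node.built          # add + two muls = 3 nodes

Node.built = 0
mul_jvp(x, dx, y, ZERO, skip_zeros=True)
skipped = Node.built            # single mul = 1 node
```

The savings compound down the graph: every eliminated zero term also removes all nodes that would have consumed it.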

Strategy 2: Compact basis representation

Replace Tensor.stack(*[Tensor.full((), v) for v in 305_values]) with more compact UOp representations:

Option A: Buffer-backed basis

basis_col = Tensor(realized_basis_column)  # single BUFFER UOp

- Pro: only ~5 UOps for the basis
- Con: scheduler can't see values → no dead-code elimination

Option B: Segmented CONST blocks

Instead of 305 individual CONSTs, group consecutive identical values:

# If basis = [0,0,0,1,1,0,0,0,1,0,...], represent as:
# zeros(3).cat(ones(2)).cat(zeros(3)).cat(ones(1)).cat(zeros(1))...
For 8 colors with ~44 nonzeros each spread across 305 positions, segments ≈ 88 → ~264 UOps vs ~1830.
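The segmentation itself is a run-length encoding of the basis column. A minimal sketch (pure Python; the name `segments` is invented here):

```python
from itertools import groupby

def segments(column):
    """Run-length encode a basis column into (value, length) runs,
    grouping consecutive equal values into one segment."""
    return [(v, len(list(g))) for v, g in groupby(column)]

basis = [0, 0, 0, 1, 1, 0, 0, 0, 1, 0]
runs = segments(basis)
# runs == [(0, 3), (1, 2), (0, 3), (1, 1), (0, 1)]: each run becomes one
# zeros(n)/ones(n) block in the cat chain, a few UOps per segment instead
# of ~6 per element
```

The UOp count then scales with the number of value transitions in the column, not with the number of inputs.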

Strategy 3: Increase scheduler robustness

3a. Raise REWRITE_STACK_LIMIT in the fork. This might allow larger graphs but with O(N²) codegen time.

3b. Pre-simplification pass: Before scheduling, run a focused constant-folding pass:

pm_pre_simplify = PatternMatcher([
    # src as a list (not a tuple) matches the operands in either order
    (UPat(Ops.MUL, src=[UPat.cvar("c"), UPat.var("x")]), lambda c, x: x.const_like(0) if c.arg == 0 else None),
    (UPat(Ops.ADD, src=[UPat.cvar("c"), UPat.var("x")]), lambda c, x: x if c.arg == 0 else None),
])
sink = graph_rewrite(sink, pm_pre_simplify)

Strategy 4: Per-column scheduling with caching

Instead of building one giant graph with all colors stacked, build and schedule each color's JVP independently:

for j in range(ncolors):
    jvp_j = jvp((output,), (x,), (basis[:, j],))[0]
    schedule_j = schedule(jvp_j)  # Small graph → fast

Then combine the scheduled results. This keeps each scheduling call within the ~3000 node limit.
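The caching half of this strategy could key on each column's sparsity pattern, so structurally identical columns compile once. A minimal sketch, assuming the schedule depends only on which inputs are active (both `schedule_per_column` and `build_schedule` are names invented here):

```python
def schedule_per_column(basis_columns, build_schedule):
    """Schedule each color's JVP separately, memoizing on the column's
    sparsity pattern so structurally identical columns compile once."""
    cache, schedules = {}, []
    for col in basis_columns:
        pattern = tuple(i for i, v in enumerate(col) if v != 0)
        if pattern not in cache:
            cache[pattern] = build_schedule(col)   # small graph -> fast
        schedules.append(cache[pattern])
    return schedules

# Example: the first and third columns share a sparsity pattern, so the
# (stubbed) schedule builder runs only twice for three columns.
cols = [[1, 0, 0], [0, 1, 0], [2, 0, 0]]
calls = []
scheds = schedule_per_column(cols, lambda c: calls.append(1) or len(calls))
```

Each per-column graph stays well under the ~3,000-node scheduling limit, and repeated patterns amortize the scheduling cost across colors.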


Phase 1: Quick wins

  1. JVP constant folding (Strategy 1): Add _is_zero checks to all JVP rules. Small code change, ~15% graph reduction.
  2. Pre-simplification pass (Strategy 3b): Fold CONST(0) * x and CONST(0) + x before scheduling.

Phase 2: Tinygrad improvements (long-term)

  1. Early constant folding: Fold CONST(0) * x, CONST(0) + x at UOp creation time.
  2. Compact basis UOps: More efficient representation for sparse constant tensors.
  3. Range specialization: Teach the codegen to specialize loop bodies when a RANGE variable appears in constant comparisons.

Appendix: Stack Limit Experiments

Raising REWRITE_STACK_LIMIT from 250K to 5M allows larger graphs to be processed:

N    UOps     Default (250K)   5M limit   Code lines
7    3,071    5.2s ✓           -          ~300
10   3,616    FAILS            10.0s ✓    ~500
15   4,521    FAILS            18.4s ✓    877
20   -        FAILS            31.8s ✓    1,157
30   -        FAILS            75.0s ✓    1,717
50   21,410   FAILS            FAILS      -

The scaling is approximately O(N²) in codegen time, and the generated code is a flat sequence of scalar operations with NO loops — every stage is fully inlined.
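The quadratic claim can be sanity-checked against the table above with a log-log least-squares fit over the four timed points (pure Python; the data are the 5M-limit timings):

```python
import math

# (N, seconds) pairs from the 5M-limit column above
points = [(10, 10.0), (15, 18.4), (20, 31.8), (30, 75.0)]

xs = [math.log(n) for n, _ in points]
ys = [math.log(t) for _, t in points]
mx, my = sum(xs) / len(xs), sum(ys) / len(ys)
slope = sum((x - mx) * (y - my) for x, y in zip(xs, ys)) \
      / sum((x - mx) ** 2 for x in xs)
# slope comes out near 1.8: codegen time grows roughly as N^2
```

The fitted exponent is slightly below 2 on these four points, consistent with "approximately O(N²)" over this range.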

By contrast, the per-block approach (N=50) generates in 0.32s with 226 lines containing clean for loops, and the vmap approach generates in 0.12s with 145 lines. Both dominate unrolling on codegen time and code size; unrolling's only remaining advantage is its per-element dead-code elimination.

Appendix: Generated Code Comparison

vmap approach (N=50, 145 lines, 0.12s)

for (stage = 0..50):        // batch dim
    load primal values[stage]
    for (color = 0..7):     // color dim
        load basis[stage, color] from buffer
        compute JVP(primal, basis)
        store result
- Loads basis from buffer at runtime → no dead-code elimination per color
- But clean loop structure, good cache behavior

Per-block approach (N=50, 226 lines, 0.32s)

// Kernel 1: compute per-stage unrolled Jacobian
for (stage = 0..49):
    load primal values[stage]
    for (color = 0..7):
        // basis is implicit via (color == k) checks
        compute Jacobian entries with dead-code elimination
        store compressed result

// Kernel 2: assemble global CSC from per-stage blocks
for (stage = 0..49):
    copy from compressed to global positions
- Basis is encoded as (color == k) conditionals → per-color specialization
- Loop structure maintained → scalable
- Dead-code elimination within each stage → efficient per-iteration code

Unrolled approach (N=7, ~300 lines, 5.2s)

// Single flat kernel, no loops
val0 = input[0]
val1 = input[1]
...
alu0 = sin(val0)
alu1 = cos(val1)
... (1000+ lines of scalar ops)
output[483] = (0.166... * (alu475 + ...))
- Everything inlined → O(N × colors × ops_per_stage) scalar operations
- Massive register pressure, poor cache behavior
- Code size grows linearly with N