# spjacobian Scalability Analysis

## Problem Statement

`spjacobian` applied to large functions (e.g., `eq_constraints_fn2` of the bicycle OCP with N=50) fails because tinygrad's scheduler doesn't scale with graph size. The scheduler's `graph_rewrite` hits the `REWRITE_STACK_LIMIT` (250K) at ~3,500 UOp nodes.
## Quantified Bottleneck

For `eq_constraints_fn2` (bicycle OCP, N=50, 305 inputs → 204 outputs, 8 colors, 1204 nnz):
| Component | UOp count |
|---|---|
| Base function graph | 169 |
| One JVP (generic tangent) | 367 |
| One JVP (CONST-based basis column) | 2,185 |
| All 8 JVPs stacked | 14,101 |
| + uncompression (1204 gathers) | 21,410 |
The CONST-based basis explodes the graph because `Tensor.stack(*[Tensor.full((), v) for v in 305_values])` creates ~6 UOps per element (`CONST` + `RESHAPE` + `PAD_bounds` + `PAD` + `ADD` chain) ≈ 1,830 UOps per basis column.
The scheduler breaks at N=10 (3,616 nodes). N=7 (3,071 nodes) succeeds in ~5s.
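The 8 colors in the numbers above come from coloring the Jacobian's column sparsity pattern so that structurally orthogonal columns share one JVP seed. A minimal greedy distance-1 coloring sketch on a toy pattern (pure Python; illustrative only, not the spjacobian implementation):

```python
def greedy_color(cols):
    """Greedy distance-1 column coloring: columns whose row sets
    intersect get different colors; structurally orthogonal columns
    share a color and hence a single JVP seed vector."""
    colors = []
    for rows in cols:
        used = {c for c, other in zip(colors, cols) if rows & other}
        c = 0
        while c in used:
            c += 1
        colors.append(c)
    return colors

# Toy 4x4 sparsity pattern: one set of nonzero row indices per column
cols = [{0, 1}, {1, 2}, {0, 3}, {2}]
print(greedy_color(cols))  # [0, 1, 1, 0] -> 2 seeds instead of 4
```

With a good coloring, the number of JVP calls is the number of colors (8 here) rather than the number of inputs (305).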
## Scaling comparison

| Approach | Graph size | Codegen time (N=50) | Code quality |
|---|---|---|---|
| `spjacobian(unroll=True)` | O(n_colors × n_inputs + n_nnz) | FAILS | Best (DCE) |
| `spjacobian(unroll=False)` (vmap) | O(function_graph) | 0.12s | Good (no DCE) |
| Per-block (`eq_constraints_jac_fn`) | O(block_graph + assembly) | 0.32s | Best (per-block DCE) |
## Root Causes

### 1. tinygrad doesn't fold constants at UOp creation time

```python
from tinygrad import Tensor

x = Tensor.empty(5)
zero = x.uop.const_like(0)
result = zero * x.uop   # creates a MUL UOp, NOT folded to CONST(0)
result2 = zero + x.uop  # creates an ADD UOp, NOT folded to x
```

These are only folded during `graph_rewrite` with the symbolic pattern matcher — which happens in the scheduler, after the graph is fully built.
### 2. Tensor-level operations prevent scalar zero-skipping

The function operates on batch tensors (e.g., dynamics on `(4, 50)` inputs). Even when the tangent vector has many zeros at the scalar level, the tangent of a `(4, 50)` tensor is a single UOp, not 200 individual scalar UOps. We can't detect per-element zeros at the UOp level.

This is the fundamental mismatch with CasADi: CasADi works at the scalar expression level (MX/SX), where each element is a separate node, enabling per-element zero-skipping. Tinygrad works at the tensor level, where batch operations are single nodes.
### 3. Tensor.stack of CONSTs is extremely expensive

Each element in the basis creates a `CONST → RESHAPE → PAD → ADD` chain. For 305 elements × 8 colors = 2,440 scalars, this creates ~14,640 UOps just for the basis representation (before any JVP computation).
## Strategies

### Strategy 1: JVP-level constant folding

Even though batch-level zero-skipping doesn't work, tensor-level zero-skipping does help at block boundaries. For example, for a basis column that doesn't affect `gamma`, the tangent of `gamma` is `CONST(0)`, and all downstream operations involving only `gamma`'s tangent can be eliminated.

Implementation: modify the JVP rules to check for `CONST(0)` tangents:
```python
# UOp and Ops come from tinygrad's ops module (exact path varies by version)

def _is_zero(u: UOp) -> bool:
    return u.op is Ops.CONST and u.arg == 0

def _mul_jvp(ctx, ret, x, y):
    # Product rule d(x*y) = y*dx + x*dy, skipping known-zero tangents
    dx, dy = _tan(ctx, x), _tan(ctx, y)
    if _is_zero(dx) and _is_zero(dy): return _zero_like(ret)
    if _is_zero(dx): return x * dy
    if _is_zero(dy): return y * dx
    return y * dx + x * dy
```
Impact assessment: for the bicycle OCP with 8 colors, the tangent of `gamma` is `CONST(0)` for ~7 out of 8 colors. This eliminates gamma-related derivative computations, reducing the graph by roughly 15% per JVP call. Helpful, but not sufficient on its own.
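A toy model (not tinygrad's classes) shows what the zero checks buy: a product rule without them still drags a multiply and an add into the graph for a zero tangent, while the zero-skipping rule emits a single multiply.

```python
class Node:
    """Tiny stand-in for a UOp: an op name plus source nodes."""
    def __init__(self, op, *src):
        self.op, self.src = op, src
    def __mul__(self, other): return Node("MUL", self, other)
    def __add__(self, other): return Node("ADD", self, other)

ZERO = Node("CONST0")

def is_zero(u):
    return u.op == "CONST0"

def mul_jvp(x, dx, y, dy):
    # Product rule with zero-skipping, mirroring the _mul_jvp idea above
    if is_zero(dx) and is_zero(dy): return ZERO
    if is_zero(dx): return x * dy
    if is_zero(dy): return y * dx
    return y * dx + x * dy

def count(u, seen=None):
    # Number of unique nodes reachable from u
    seen = set() if seen is None else seen
    if id(u) in seen: return 0
    seen.add(id(u))
    return 1 + sum(count(s, seen) for s in u.src)

x, y, dy = Node("X"), Node("Y"), Node("DY")
naive = y * ZERO + x * dy         # what a non-checking rule would build
folded = mul_jvp(x, ZERO, y, dy)  # zero-skipped: just x * dy
print(count(naive), count(folded))  # 7 3
```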
### Strategy 2: Compact basis representation

Replace `Tensor.stack(*[Tensor.full((), v) for v in 305_values])` with more compact UOp representations:

**Option A: Buffer-backed basis**

- Pro: only ~5 UOps for the basis
- Con: the scheduler can't see the values → no dead code elimination

**Option B: Segmented CONST blocks**

Instead of 305 individual CONSTs, group consecutive identical values:

```python
# If basis = [0,0,0,1,1,0,0,0,1,0,...], represent as:
# zeros(3).cat(ones(2)).cat(zeros(3)).cat(ones(1)).cat(zeros(1))...
```
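The segmentation step itself is a run-length encoding; `segment_basis` below is a hypothetical helper sketching it, so the number of building blocks scales with the number of runs rather than with all 305 elements:

```python
from itertools import groupby

def segment_basis(values):
    """Run-length encode a basis column into (value, run_length) pairs,
    so the representation scales with the number of runs, not elements."""
    return [(v, len(list(g))) for v, g in groupby(values)]

basis = [0, 0, 0, 1, 1, 0, 0, 0, 1, 0]
print(segment_basis(basis))  # [(0, 3), (1, 2), (0, 3), (1, 1), (0, 1)]
```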
### Strategy 3: Increase scheduler robustness

**3a. Raise `REWRITE_STACK_LIMIT` in the fork.** This might allow larger graphs, but with O(N²) codegen time.

**3b. Pre-simplification pass.** Before scheduling, run a focused constant-folding pass:
```python
# PatternMatcher, UPat, graph_rewrite come from tinygrad's ops module
# (exact import path varies across tinygrad versions)

pm_pre_simplify = PatternMatcher([
    # 0 * x -> 0
    (UPat(Ops.MUL, src=(UPat.cvar("c"), UPat.var("x"))),
     lambda c, x: x.const_like(0) if c.arg == 0 else None),
    # 0 + x -> x
    (UPat(Ops.ADD, src=(UPat.cvar("c"), UPat.var("x"))),
     lambda c, x: x if c.arg == 0 else None),
])
sink = graph_rewrite(sink, pm_pre_simplify)
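The same idea, modeled without tinygrad: a bottom-up fold of `0 * x → 0` and `0 + x → x` over plain tuples, just to show how much of such a graph collapses before the scheduler ever sees it:

```python
def simplify(e):
    """Bottom-up fold of 0*x -> 0 and 0+x -> x over (op, lhs, rhs) tuples."""
    if not isinstance(e, tuple):
        return e  # leaf: variable name or constant
    op, l, r = e[0], simplify(e[1]), simplify(e[2])
    if op == "mul" and (l == 0 or r == 0): return 0
    if op == "add" and l == 0: return r
    if op == "add" and r == 0: return l
    return (op, l, r)

# 0*x + (y + 0) collapses to just y before any scheduling happens
expr = ("add", ("mul", 0, "x"), ("add", "y", 0))
print(simplify(expr))  # y
```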
### Strategy 4: Per-column scheduling with caching

Instead of building one giant graph with all colors stacked, build and schedule each color's JVP independently:

```python
for j in range(ncolors):
    jvp_j = jvp((output,), (x,), (basis[:, j],))[0]
    schedule_j = schedule(jvp_j)  # small graph -> fast to schedule
```

Then combine the scheduled results. This keeps each scheduling call within the ~3,000-node limit.
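Combining the per-color results is the uncompression step: each structural nonzero `J[r, c]` is read from the compressed column of c's color. A minimal sketch with illustrative names (`uncompress` and `color_of` are not the actual API):

```python
def uncompress(compressed, rows, cols, color_of):
    """Recover J[r, c] for each structural nonzero: compressed[r][k]
    holds the sum of J[r, c] over columns c of color k, which equals
    J[r, c] alone whenever the coloring is valid."""
    return [compressed[r][color_of[c]] for r, c in zip(rows, cols)]

color_of = [0, 1, 0]              # 3 columns, 2 colors
compressed = [[1.0, 2.0],         # row 0: color-0 sum, color-1 sum
              [3.0, 4.0]]         # row 1
rows, cols = [0, 0, 1], [0, 1, 2] # nnz at (0,0), (0,1), (1,2)
print(uncompress(compressed, rows, cols, color_of))  # [1.0, 2.0, 3.0]
```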
## Recommended Action Plan

### Phase 1: Quick wins

- **JVP constant folding (Strategy 1):** add `_is_zero` checks to all JVP rules. Small code change, ~15% graph reduction.
- **Pre-simplification pass (Strategy 3b):** fold `CONST(0) * x` and `CONST(0) + x` before scheduling.
### Phase 2: Tinygrad improvements (long-term)

- **Early constant folding:** fold `CONST(0) * x` and `CONST(0) + x` at UOp creation time.
- **Compact basis UOps:** a more efficient representation for sparse constant tensors.
- **Range specialization:** teach the codegen to specialize loop bodies when a RANGE variable appears in constant comparisons.
## Appendix: Stack Limit Experiments

Raising `REWRITE_STACK_LIMIT` from 250K to 5M allows larger graphs to be processed:
| N | UOps | Default (250K) | 5M limit | Code lines |
|---|---|---|---|---|
| 7 | 3,071 | 5.2s ✓ | - | ~300 |
| 10 | 3,616 | FAILS | 10.0s ✓ | ~500 |
| 15 | 4,521 | FAILS | 18.4s ✓ | 877 |
| 20 | - | FAILS | 31.8s ✓ | 1,157 |
| 30 | - | FAILS | 75.0s ✓ | 1,717 |
| 50 | 21,410 | FAILS | FAILS | - |
The scaling is approximately O(N²) in codegen time, and the generated code is a flat sequence of scalar operations with no loops: every stage is fully inlined.

By contrast, the per-block approach (N=50) generates in 0.32s with 226 lines containing clean `for` loops, and the vmap approach generates in 0.12s with 145 lines. Both are superior on every metric.
## Appendix: Generated Code Comparison

### vmap approach (N=50, 145 lines, 0.12s)

```
for (stage = 0..50):            // batch dim
    load primal values[stage]
    for (color = 0..7):         // color dim
        load basis[stage, color] from buffer
        compute JVP(primal, basis)
        store result
```
### Per-block approach (N=50, 226 lines, 0.32s)

```
// Kernel 1: compute per-stage unrolled Jacobian
for (stage = 0..49):
    load primal values[stage]
    for (color = 0..7):
        // basis is implicit via (color == k) checks
        compute Jacobian entries with dead-code elimination
        store compressed result

// Kernel 2: assemble global CSC from per-stage blocks
for (stage = 0..49):
    copy from compressed to global positions
```

- `(color == k)` conditionals → per-color specialization
- Loop structure maintained → scalable
- Dead-code elimination within each stage → efficient per-iteration code