# Vmap: Design & Implementation

## Goal
Implement a `vmap` function that transforms a function to operate over an extra
batch dimension by directly rewriting its UOp graph with a `PatternMatcher`.
## Interface

```python
def vmap[**P](
    fn: Callable[P, Tensor],
    in_axes: int | tuple[int | None, ...] = 0,
    out_axis: int = 0,
) -> Callable[P, Tensor]: ...
```
- `in_axes`: Which axis of each input to map over. `None` means broadcast (don't map). An `int` applies to all inputs.
- `out_axis`: Where to place the batch dimension in the output.
- All batched dimensions must have the same size (asserted at runtime).
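For illustration, the `in_axes` normalization could look like the sketch below. The helper name is hypothetical, not necessarily what the implementation uses:

```python
def normalize_in_axes(in_axes, n_inputs: int) -> tuple:
    """Expand an int/tuple in_axes spec into one entry per positional input."""
    if isinstance(in_axes, int):
        return (in_axes,) * n_inputs      # a single int applies to all inputs
    if len(in_axes) != n_inputs:
        raise ValueError(f"in_axes has {len(in_axes)} entries for {n_inputs} inputs")
    return tuple(in_axes)
```
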
## Design

### High-level algorithm
1. **Trace**: Call `fn` with placeholder (empty) tensors that have unbatched shapes. This produces an output `Tensor` whose `.uop` graph describes the computation.
2. **Walk & rewrite**: Walk the output UOp graph in topological order (from inputs toward output). For each node, apply a `PatternMatcher` to produce the "batched" version of that node.
3. **Context**: A `dict[UOp, tuple[UOp, bool]]` maps each original UOp to `(batched_uop, has_batch)`:
   - `has_batch=True`: `batched_uop` has shape `(batch_size, *original_shape)`
   - `has_batch=False`: `batched_uop` has the original shape (constant/broadcast)
4. **Output**: Extract the final batched UOp, optionally move the batch axis to `out_axis`, and wrap in a `Tensor`.
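The walk-and-rewrite step can be sketched generically. Here a toy `Node` class and a plain rules dict stand in for tinygrad's `UOp` and `PatternMatcher`; all names below are illustrative, not the real API:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    """Toy stand-in for a UOp: an op name plus source nodes."""
    op: str
    srcs: tuple = ()

def toposort(root: Node) -> list[Node]:
    """Post-order DFS: every node appears after its sources."""
    seen, order = set(), []
    def visit(n: Node) -> None:
        if n in seen:
            return
        seen.add(n)
        for s in n.srcs:
            visit(s)
        order.append(n)
    visit(root)
    return order

def batch_rewrite(root: Node, ctx: dict, rules: dict):
    """Rewrite each node into (batched_node, has_batch), inputs first."""
    for node in toposort(root):
        if node not in ctx:              # inputs are seeded before the walk
            ctx[node] = rules[node.op](node, ctx)
    return ctx[root]

# A unary rule: apply the op to the (possibly batched) source,
# passing the batch flag through unchanged.
x = Node("input")
y = Node("sin", (x,))
ctx = {x: (Node("batched_input"), True)}     # seeded batched input
rules = {"sin": lambda n, c: (Node("sin", (c[n.srcs[0]][0],)), c[n.srcs[0]][1])}
```

Calling `batch_rewrite(y, ctx, rules)` then yields the rewritten `sin` node together with `has_batch=True`.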
### Context object
A helper looks up a source UOp in the cache. If a UOp is not in the cache, it is treated as unbatched (e.g., a constant or infrastructure node).
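Assuming the cache is a plain dict, the lookup can be as simple as the following sketch (the function name is illustrative):

```python
def lookup(cache: dict, uop):
    """Return (batched_uop, has_batch) for a source UOp.

    UOps absent from the cache (constants, infrastructure nodes)
    are treated as unbatched and passed through unchanged.
    """
    return cache.get(uop, (uop, False))
```
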
### Rules by op category

#### Inputs

Seeded before the walk. Batched inputs: `(batched_uop, True)`. Non-batched inputs: `(original_uop, False)`.
#### Unary elementwise (SIN, EXP2, LOG2, SQRT, NEG, RECIPROCAL, CAST, BITCAST)
Same batch status as source. Apply the same Tensor op to the (possibly batched) source.
#### Binary/ternary elementwise (ADD, SUB, MUL, MAX, POW, WHERE, CMP*)

If an op mixes batched and unbatched operands, broadcast the unbatched one via
`unsqueeze(0).expand(batch_size, ...)`. The result is batched if any input is batched.
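At the shape level this broadcast is just a prepend. A shape-only sketch of the effect (the real `_ensure_batched` operates on UOps; this helper name is illustrative):

```python
def ensure_batched_shape(shape: tuple, has_batch: bool, batch_size: int) -> tuple:
    """Shape after unsqueeze(0).expand(batch_size, ...) on an unbatched operand."""
    if has_batch:
        return shape                 # already (batch_size, *original_shape)
    return (batch_size, *shape)      # prepend the batch dimension
```
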
#### Movement ops (RESHAPE, PERMUTE, EXPAND, PAD, SHRINK, FLIP)

- Unbatched: apply as-is.
- Batched: adjust for the batch dim at position 0:
  - RESHAPE: target = `(batch_size, *original_target_shape)`
  - PERMUTE: perm = `(0, *(p+1 for p in original_perm))`
  - EXPAND: target = `(batch_size, *original_target_shape)`
  - PAD: padding = `((0,0), *original_padding)`
  - SHRINK: shrink = `((0, batch_size), *original_shrink)`
  - FLIP: mask = `(False, *original_mask)`
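These adjustments are pure tuple arithmetic. A sketch collecting them in one helper (illustrative name, not the implementation's):

```python
def shift_for_batch(op: str, arg, batch_size: int):
    """Adjust a movement op's argument for a batch dim at axis 0."""
    if op in ("RESHAPE", "EXPAND"):
        return (batch_size, *arg)            # prepend batch size to target shape
    if op == "PERMUTE":
        return (0, *(p + 1 for p in arg))    # batch axis stays first
    if op == "PAD":
        return ((0, 0), *arg)                # never pad the batch axis
    if op == "SHRINK":
        return ((0, batch_size), *arg)       # keep the whole batch axis
    if op == "FLIP":
        return (False, *arg)                 # never flip the batch axis
    raise ValueError(f"not a movement op: {op}")
```
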
#### CONTIGUOUS / CONTIGUOUS_BACKWARD

Pass the batch status through; apply `contiguous` to the (possibly batched) source.
#### REDUCE_AXIS
- Unbatched: apply as-is.
- Batched: shift reduction axes by +1 (skip batch dim at 0).
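The axis shift in isolation (a sketch; the helper name is illustrative):

```python
def shift_reduce_axes(axes: tuple, has_batch: bool) -> tuple:
    """Shift REDUCE_AXIS axes past the batch dim at position 0."""
    return tuple(a + 1 for a in axes) if has_batch else tuple(axes)
```
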
### Output axis handling

After the walk, if the output is unbatched (`has_batch=False`), broadcast it to
add the batch dimension (`unsqueeze(0).expand(batch_size, ...)`). This matches
JAX semantics: `vmap` always produces a batched output, even if the result is
constant w.r.t. all batched inputs.

Then, if `out_axis != 0`, move the batch axis from 0 to `out_axis` using
`Tensor.permute`.
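The permutation that moves axis 0 to `out_axis` can be built like this (sketch; the helper name is illustrative, and the resulting tuple would be handed to `Tensor.permute`):

```python
def batch_to_out_axis(ndim: int, out_axis: int) -> tuple:
    """Permutation order moving the batch axis from 0 to out_axis."""
    order = list(range(1, ndim))    # non-batch axes, in original order
    order.insert(out_axis, 0)       # slot the batch axis in at out_axis
    return tuple(order)
```
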
## Implementation Notes

### File structure

- `src/anvil/transform/_vmap.py` — Core implementation (~220 LOC)
- `tests/test_vmap.py` — Test suite
### How it works

1. The `vmapped` wrapper normalizes/validates `in_axes`, moves each batched input's batch axis to position 0, and creates unbatched placeholder tensors.
2. Calls `fn(*placeholders)` to trace the computation graph.
3. Seeds a `VmapCtx` (batch size + a cache mapping `UOp → (UOp, bool)`) with the placeholder-to-batched-input mappings.
4. Uses `walk(output.uop, input_uops)` to traverse the graph in topological order.
5. For each node, `pm_vmap.rewrite(node, ctx=ctx)` dispatches to the appropriate rule, which constructs the batched UOp via Tensor operations (for correct shape computation) and returns `(new_uop, has_batch)`.
6. Finally, if `out_axis != 0`, permutes the batch dim to the requested position.
### Key design decisions

- **Tensor-level construction in rules**: We wrap UOps as `Tensor` objects to leverage tinygrad's shape validation and movement-op helpers, then extract `.uop`.
- **No custom ops**: The transformation is purely graph-to-graph, adding real dimensions that the scheduler can reason about. No fork of tinygrad needed.
- **`_ensure_batched` for broadcasting**: When a binary op mixes batched and unbatched operands, the unbatched one is `unsqueeze(0).expand(batch_size, ...)`.
## Supported ops
| Category | Ops |
|---|---|
| Unary elementwise | NEG, SIN, LOG2, EXP2, SQRT, RECIPROCAL, CAST, BITCAST |
| Binary elementwise | ADD, SUB, MUL, MAX, POW, CMPLT, CMPNE, CMPEQ |
| Ternary | WHERE |
| Movement | RESHAPE, PERMUTE, EXPAND, PAD, SHRINK, FLIP |
| Contiguous | CONTIGUOUS, CONTIGUOUS_BACKWARD |
| Reduce | REDUCE_AXIS (ADD, MAX, MUL) |
## Not yet supported (add as needed)
- Indexing ops (tensor-based indexing, gather/scatter)
- ASSIGN, DETACH
## Progress

- [x] Core implementation in `src/anvil/transform/_vmap.py`
- [x] Tests in `tests/test_vmap.py`
- [x] Basic elementwise, broadcast, movement, reduce, validation
- [x] Nested vmap (vmap of vmap)
- [x] Integration with JVP (vmap over jvp for Jacobians)
- [ ] Add indexing op support if needed