Vmap: Design & Implementation

Goal

Implement a vmap function that transforms a function over a batch dimension by directly rewriting the UOp graph to add the extra batch dimension using a PatternMatcher.

Interface

def vmap[**P](
    fn: Callable[P, Tensor],
    in_axes: int | tuple[int | None, ...] = 0,
    out_axis: int = 0,
) -> Callable[P, Tensor]: ...
  • in_axes: Which axis of each input to map over. None means broadcast (don't map). An int applies to all inputs.
  • out_axis: Where to place the batch dimension in the output.
  • Assert all batched dimensions have the same size (checked at runtime).

Design

High-level algorithm

  1. Trace: Call fn with placeholder (empty) tensors that have unbatched shapes. This produces an output Tensor whose .uop graph describes the computation.
  2. Walk & rewrite: Walk the output UOp graph in topological order (from inputs toward output). For each node, apply a PatternMatcher to produce the "batched" version of that node.
  3. Context: A dict[UOp, tuple[UOp, bool]] maps each original UOp to (batched_uop, has_batch):
      • has_batch=True: batched_uop has shape (batch_size, *original_shape)
      • has_batch=False: batched_uop has the original shape (constant/broadcast)
  4. Output: Extract the final batched UOp, optionally move the batch axis to out_axis, and wrap it in a Tensor.
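The walk-and-rewrite loop can be illustrated on a toy DAG. This sketch uses a hypothetical `Node` class in place of tinygrad's UOp and a hard-coded rule in place of the PatternMatcher; only the traversal and cache discipline mirror the real design:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    """Toy stand-in for a UOp: an op name plus source nodes."""
    op: str
    srcs: tuple = ()

def toposort(root):
    """Topological order: inputs first, root last."""
    order, seen = [], set()
    def visit(n):
        if id(n) in seen: return
        seen.add(id(n))
        for s in n.srcs: visit(s)
        order.append(n)
    visit(root)
    return order

def batch_rewrite(root, cache):
    """cache: Node -> (rewritten_node, has_batch); pre-seeded for batched inputs."""
    for node in toposort(root):
        if node in cache:
            continue  # an input that was seeded before the walk
        srcs = [cache.get(s, (s, False)) for s in node.srcs]
        has_batch = any(b for _, b in srcs)  # batched if any source is batched
        cache[node] = (Node(node.op, tuple(u for u, _ in srcs)), has_batch)
    return cache[root]
```

In the real implementation each node is rebuilt through Tensor operations so shapes are recomputed, rather than by copying sources as here.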

Context object

@dataclass
class VmapCtx:
    batch_size: int
    cache: dict[UOp, tuple[UOp, bool]]

Helper to look up a source UOp:

def _get(ctx: VmapCtx, u: UOp) -> tuple[UOp, bool]:
    return ctx.cache.get(u, (u, False))

If a UOp is not in the cache, it's unbatched (e.g., a constant or infrastructure node).

Rules by op category

Inputs

Seeded before the walk. Batched inputs: (batched_uop, True). Non-batched inputs: (original_uop, False).

Unary elementwise (SIN, EXP2, LOG2, SQRT, NEG, RECIPROCAL, CAST, BITCAST)

Same batch status as source. Apply the same Tensor op to the (possibly batched) source.

Binary/ternary elementwise (ADD, SUB, MUL, MAX, POW, WHERE, CMP*)

If mixing batched and unbatched: broadcast the unbatched one by unsqueeze(0).expand(batch_size, ...). Result is batched if any input is batched.

Movement ops (RESHAPE, PERMUTE, EXPAND, PAD, SHRINK, FLIP)

  • Unbatched: apply as-is.
  • Batched: adjust for the batch dim at position 0:
      • RESHAPE: target = (batch_size, *original_target_shape)
      • PERMUTE: perm = (0, *(p+1 for p in original_perm))
      • EXPAND: target = (batch_size, *original_target_shape)
      • PAD: padding = ((0,0), *original_padding)
      • SHRINK: shrink = ((0, batch_size), *original_shrink)
      • FLIP: mask = (False, *original_mask)
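These argument adjustments are pure tuple arithmetic, so they can be written out directly. The helper names are hypothetical; each one prepends the batch dimension's contribution to the original op arguments:

```python
def batched_reshape(target, batch_size):
    return (batch_size, *target)        # batch dim kept as-is at position 0

def batched_permute(perm):
    return (0, *(p + 1 for p in perm))  # batch dim stays first; rest shift by 1

def batched_pad(padding):
    return ((0, 0), *padding)           # never pad the batch dim

def batched_shrink(shrink, batch_size):
    return ((0, batch_size), *shrink)   # keep the full batch dim

def batched_flip(mask):
    return (False, *mask)               # never flip the batch dim
```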

CONTIGUOUS / CONTIGUOUS_BACKWARD

Pass through batch status, apply contiguous to batched source.

REDUCE_AXIS

  • Unbatched: apply as-is.
  • Batched: shift reduction axes by +1 (skip batch dim at 0).
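The axis shift is a one-liner; this hypothetical helper shows the rule. Reducing axis i of the original tensor maps to axis i+1 of the batched tensor, so the batch dim at 0 is never reduced:

```python
def batched_reduce_axes(axes, has_batch):
    """Shift reduction axes past the batch dim when the source is batched."""
    return tuple(a + 1 for a in axes) if has_batch else tuple(axes)
```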

Output axis handling

After the walk, if the output is unbatched (has_batch=False), broadcast it to add the batch dimension (unsqueeze(0).expand(batch_size, ...)). This matches JAX semantics where vmap always produces a batched output, even if the result is constant w.r.t. all batched inputs.

Then, if out_axis != 0, move the batch axis from 0 to out_axis using Tensor.permute.
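The permutation passed to Tensor.permute can be computed as below (an illustrative helper, not the actual code): take the non-batch axes in order and insert axis 0 at the requested position.

```python
def out_axis_perm(ndim, out_axis):
    """Permutation moving the batch axis from position 0 to out_axis."""
    perm = list(range(1, ndim))  # the non-batch axes, in order
    perm.insert(out_axis, 0)     # drop the batch axis into place
    return tuple(perm)
```

For a rank-3 batched result, `out_axis=1` yields the permutation `(1, 0, 2)`.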

Implementation Notes

File structure

  • src/anvil/transform/_vmap.py — Core implementation (~220 LOC)
  • tests/test_vmap.py — Test suite

How it works

  1. The vmapped wrapper normalizes/validates in_axes, moves each batched input's batch axis to position 0, and creates unbatched placeholder tensors.
  2. Calls fn(*placeholders) to trace the computation graph.
  3. Seeds a VmapCtx (batch_size + cache mapping UOp → (UOp, bool)) with the placeholder-to-batched-input mappings.
  4. Uses walk(output.uop, input_uops) to traverse the graph in topological order.
  5. For each node, pm_vmap.rewrite(node, ctx=ctx) dispatches to the appropriate rule, which constructs the batched UOp via Tensor operations (for correct shape computation) and returns (new_uop, has_batch).
  6. Finally, if out_axis != 0, permutes the batch dim to the requested position.

Key design decisions

  • Tensor-level construction in rules: We wrap UOps as Tensor objects to leverage tinygrad's shape validation and movement op helpers, then extract .uop.
  • No custom ops: The transformation is purely graph-to-graph, adding real dimensions that the scheduler can reason about. No fork of tinygrad needed.
  • _ensure_batched for broadcasting: When a binary op mixes batched and unbatched operands, the unbatched one is broadcast via unsqueeze(0).expand(batch_size, ...).

Supported ops

  Category             Ops
  Unary elementwise    NEG, SIN, LOG2, EXP2, SQRT, RECIPROCAL, CAST, BITCAST
  Binary elementwise   ADD, SUB, MUL, MAX, POW, CMPLT, CMPNE, CMPEQ
  Ternary              WHERE
  Movement             RESHAPE, PERMUTE, EXPAND, PAD, SHRINK, FLIP
  Contiguous           CONTIGUOUS, CONTIGUOUS_BACKWARD
  Reduce               REDUCE_AXIS (ADD, MAX, MUL)

Not yet supported (add as needed)

  • Indexing ops (tensor-based indexing, gather/scatter)
  • ASSIGN, DETACH

Progress

  • [x] Core implementation in src/anvil/transform/_vmap.py
  • [x] Tests in tests/test_vmap.py
  • [x] Basic elementwise, broadcast, movement, reduce, validation
  • [x] Nested vmap (vmap of vmap)
  • [x] Integration with JVP (vmap over jvp for Jacobians)
  • [ ] Add indexing op support if needed