Skip to content

Transforms

vmap = svmap (module attribute: `vmap` is an alias for `svmap`)

dvmap(fn, in_axes=0, out_axis=0)

Dynamically vectorize a function by lazily determining the batch size from inputs.

Unlike svmap/vmap, the UOp graph is rebuilt on each call, which allows the batch size to vary. Useful inside traced contexts (e.g. AD operators) where the batch size is not known at graph-construction time.

Parameters:

Name Type Description Default
fn Callable[P, Tensor]

Callable accepting Tensors and returning a single Tensor.

required
in_axes int | tuple[int | None, ...]

Which axis is the batch dimension for each input. An int applies to all inputs; None means that input is not batched (broadcast).

0
out_axis int

Where to place the batch dimension in the output.

0
Source code in src/anvil/transform/_vmap.py
def dvmap[**P](fn: Callable[P, Tensor], in_axes: int | tuple[int | None, ...] = 0, out_axis: int = 0) -> Callable[P, Tensor]:
  """Dynamically vectorize a function by lazily determining the batch size from inputs.

  Unlike `svmap`/`vmap`, the UOp graph is rebuilt on each call, which allows the batch
  size to vary. Useful inside traced contexts (e.g. AD operators) where the batch size
  is not known at graph-construction time.

  Args:
    fn: Callable accepting Tensors and returning a single Tensor.
    in_axes: Which axis is the batch dimension for each input. An int applies to all
        inputs; ``None`` means that input is not batched (broadcast).
    out_axis: Where to place the batch dimension in the output.
  """

  def vmapped(*args: Tensor) -> Tensor:
    assert all(isinstance(a, Tensor) for a in args)
    validated_axes = _validate_in_axes(in_axes, tuple(Arg(cast(Shape, a.shape), a.dtype) for a in args))

    # determine batch size
    batch_sizes = [a.shape[ax] for a, ax in zip(args, validated_axes) if ax is not None]
    if not batch_sizes:
      raise ValueError("at least one input must have a mapped axis")
    if not all(s == batch_sizes[0] for s in batch_sizes):
      raise ValueError(f"inconsistent batch sizes: {batch_sizes}")
    batch_size = batch_sizes[0]

    # build placeholders (unbatched) and seed context with batched inputs
    placeholders: list[Tensor] = []
    ctx = VmapCtx(batch_size=batch_size)
    for a, ax in zip(args, validated_axes):
      if ax is not None:
        batched = a.permute(*([ax] + [i for i in range(a.ndim) if i != ax])) if ax != 0 else a
        ph = Tensor.empty(*batched.shape[1:], dtype=a.dtype)
        placeholders.append(ph)
        ctx[ph.uop] = (batched.uop, True)
      else:
        ph = Tensor.empty(*a.shape, dtype=a.dtype)
        placeholders.append(ph)
        ctx[ph.uop] = (a.uop, False)

    # trace and rewrite (resolve any CALL ops from composed NumericalFunctions)
    output = fn(*placeholders)  # ty:ignore[missing-argument, invalid-argument-type]
    resolved_uop = resolve_calls(output.uop)
    walk_rewrite(resolved_uop, pm_vmap, ctx)

    # extract output — always add batch dim (like JAX), even for constants
    bout, has_batch = ctx[resolved_uop]
    if not has_batch:
      bout = bout.unsqueeze(0).expand(batch_size, *bout.shape)
    if out_axis != 0:
      ax = (out_axis % Tensor(bout).ndim) if out_axis < 0 else out_axis
      perm = list(range(Tensor(bout).ndim))
      perm.insert(ax, perm.pop(0))
      bout = bout.permute(*perm)
    return Tensor(bout.simplify())

  return vmapped  # ty:ignore[invalid-return-type]

svmap(fn, batch_size, in_axes=0, out_axis=0)

Statically vectorize a NumericalFunction with a fixed batch size.

The UOp graph is rewritten once at construction time and reused on each call via UOp.substitute(), making it more efficient than dvmap for repeated evaluation with the same batch size.

Parameters:

Name Type Description Default
fn NumericalFunction

The NumericalFunction to vectorize (must have a single output).

required
batch_size int

Fixed batch dimension size.

required
in_axes int | tuple[int | None, ...]

Which axis is the batch dimension for each input. An int applies to all inputs; None means that input is not batched (broadcast).

0
out_axis int

Where to place the batch dimension in the output.

0
Source code in src/anvil/transform/_vmap.py
def svmap(fn: NumericalFunction, batch_size: int, in_axes: int | tuple[int | None, ...] = 0, out_axis: int = 0) -> NumericalFunction:
  """Build a batched version of a single-output NumericalFunction for one fixed batch size.

  The traced UOp graph of ``fn`` is rewritten once, here, and every call of the
  returned function only substitutes the actual inputs into that rewritten graph
  via ``UOp.substitute()``. This makes it cheaper than `dvmap` when the same batch
  size is evaluated repeatedly.

  Args:
    fn: The NumericalFunction to vectorize (must have a single output).
    batch_size: Fixed batch dimension size.
    in_axes: Batch axis for each input. A single int is used for every input;
        ``None`` marks an input as not batched (broadcast).
    out_axis: Position of the batch dimension in the output.

  Returns:
    A NumericalFunction named ``"v" + fn.name`` whose inputs are the batched args.
  """
  graph = fn.function_graph
  unbatched_args = fn.inputs
  num_inputs = len(unbatched_args)

  # a scalar in_axes is shared by all inputs
  if isinstance(in_axes, int):
    in_axes = (in_axes,) * num_inputs
  assert len(in_axes) == num_inputs, f"in_axes length {len(in_axes)} != number of args {num_inputs}"

  # check each axis against the batched rank (unbatched rank + 1), then make it non-negative
  normalized: list[int | None] = []
  for idx, (axis, arg) in enumerate(zip(in_axes, unbatched_args)):
    if axis is None:
      normalized.append(None)
      continue
    rank = arg.ndim + 1
    assert -rank <= axis < rank, f"axis {axis} out of bounds for arg {idx} with {rank} dims"
    normalized.append(axis % rank)
  mapped_axes = cast(More[int], tuple(normalized))

  # the single output gets the same treatment: bounds-check, then normalize
  assert len(fn.outputs) == 1
  out_arg = fn.outputs[0]
  out_rank = out_arg.ndim + 1
  assert -out_rank <= out_axis < out_rank
  out_axis %= out_rank

  # batched signature: insert batch_size at the mapped axis, leave broadcast inputs alone
  batched_args: list[Arg] = []
  for axis, arg in zip(mapped_axes, unbatched_args):
    shape = arg.shape if axis is None else insert(make_more(arg.shape), batch_size, axis)
    batched_args.append(Arg(shape, arg.dtype))
  batched_inputs = tuple(batched_args)
  batched_placeholders = tuple(Tensor.empty(make_more(b.shape), dtype=b.dtype) for b in batched_inputs)

  # seed the rewrite context: each traced input maps to (replacement UOp, is_batched),
  # with the batch axis of mapped inputs permuted to the front
  seeds = {}
  for axis, trace_in, batched_ph in zip(mapped_axes, graph._trace_inputs, batched_placeholders):
    if axis is None:
      seeds[trace_in.uop] = (trace_in.uop, False)
    elif axis == 0:
      seeds[trace_in.uop] = (batched_ph.uop, True)
    else:
      order = [axis, *(d for d in range(batched_ph.ndim) if d != axis)]
      seeds[trace_in.uop] = (batched_ph.uop.permute(*order), True)
  ctx = VmapCtx(batch_size=batch_size, initial=seeds)

  # rewrite the traced output (CALL ops from composed NumericalFunctions are resolved first)
  root_uop = resolve_calls(graph._trace_outputs[0].uop)
  rewritten = walk_rewrite(root_uop, pm_vmap, ctx)
  if rewritten is None:
    # output independent of the inputs: nothing was rewritten, so broadcast the
    # original graph (this happens e.g. for gradients w.r.t. an unused input)
    unbatched_uop, needs_broadcast = root_uop, True
  else:
    unbatched_uop, was_batched = rewritten
    needs_broadcast = not was_batched

  if needs_broadcast:
    bout = Tensor(unbatched_uop).unsqueeze(0).expand(batch_size, *make_more(out_arg.shape)).uop
  else:
    bout = unbatched_uop

  # relocate the leading batch dim to out_axis, keeping the other dims in order
  if out_axis != 0:
    batched_out = Tensor(bout)
    order = [*range(1, out_axis + 1), 0, *range(out_axis + 1, batched_out.ndim)]
    bout = batched_out.permute(*order).uop

  def wrapper(*args) -> Tensor:
    # swap every placeholder in the pre-built graph for the caller's tensor:
    # broadcast inputs were traced directly, mapped inputs via their batched placeholder
    subs: dict[UOp, UOp] = {}
    for batched_ph, trace_in, actual, axis in zip(batched_placeholders, graph._trace_inputs, args, mapped_axes):
      key = trace_in.uop if axis is None else batched_ph.uop
      subs[key] = actual.uop
    return Tensor(bout.substitute(subs))

  return NumericalFunction("v" + fn.name, wrapper, batched_inputs)  # ty: ignore[no-matching-overload]