Skip to content

Core

Arg dataclass

Specification of a NumericalFunction input: shape and dtype.

Parameters:

Name Type Description Default
shape OneOrMore[int]

Dimensions of the input tensor. A single int for 1D, a tuple for multi-dimensional.

required
dtype DTypeLike

Data type (default: float64).

float64
Source code in src/anvil/codegen/numerical_function.py
@dataclass(frozen=True)
class Arg:
  """Describe one NumericalFunction input by its shape and element type.

  Args:
    shape: Dimensions of the input tensor. A single int for 1D, a tuple for multi-dimensional.
    dtype: Data type (default: float64).
  """

  shape: OneOrMore[int]
  dtype: DTypeLike = dtypes.float64

  @property
  def ndim(self) -> int:
    """Number of dimensions, i.e. the length of ``make_more(self.shape)``."""
    dims = make_more(self.shape)
    return len(dims)

CodegenIntConstant dataclass

A compile-time integer constant emitted as constexpr in generated C++ code.

Supports arithmetic (+, -, *, //) that propagates symbolic expressions and dependency tracking for correct declaration ordering.

Parameters:

Name Type Description Default
name str

Identifier in generated code (also used for the symbolic expression).

required
value int

Concrete integer value.

required
type Literal['int', 'int32_t', 'int64_t']

C++ integer type to emit.

'int'
Source code in src/anvil/codegen/common.py
@dataclass(frozen=True)
class CodegenIntConstant:
  """A compile-time integer constant emitted as ``constexpr`` in generated C++ code.

  Supports arithmetic (``+``, ``-``, ``*``, ``//``) that propagates symbolic
  expressions and dependency tracking for correct declaration ordering.

  Args:
    name: Identifier in generated code (also used for the symbolic expression).
    value: Concrete integer value.
    type: C++ integer type to emit.
  """

  name: str
  value: int
  type: Literal["int", "int32_t", "int64_t"] = "int"
  _dependencies: set[Self] = field(default_factory=set, compare=False, hash=False)

  def __int__(self) -> int:
    return self.value

  def __repr__(self) -> str:
    return self.name

  def dependencies(self) -> list[Self]:
    """Returns immediate dependencies (for toposort)."""
    return list(self._dependencies)

  def all_dependencies(self) -> set[Self]:
    """Returns all transitive dependencies including self."""
    result = {self}
    for dep in self._dependencies:
      result.update(dep.all_dependencies())
    return result

  @staticmethod
  def from_expr(name: str, expr: "CodegenIntConstant") -> "CodegenIntConstant":
    """Create a named constant from an expression, preserving only base dependencies (no intermediate expressions)."""
    # Collect all dependencies but filter out unnamed intermediate expressions
    all_deps = expr.all_dependencies()
    # Keep only constants that have simple names (no parentheses = base constants or properly named derived constants)
    base_deps = set(dep for dep in all_deps if "(" not in dep.name and dep != expr)
    return CodegenIntConstant(name=name, value=int(expr), type=expr.type, _dependencies=base_deps)

  def __add__(self, other: "CodegenIntConstant | int") -> "CodegenIntConstant":
    if isinstance(other, int):
      return CodegenIntConstant(name=f"({self.name} + {other})", value=self.value + other, type=self.type, _dependencies={self})
    return CodegenIntConstant(name=f"({self.name} + {other.name})", value=self.value + other.value, type=self.type, _dependencies={self, other})

  def __radd__(self, other: int) -> "CodegenIntConstant":
    return self.__add__(other)

  def __sub__(self, other: "CodegenIntConstant | int") -> "CodegenIntConstant":
    if isinstance(other, int):
      return CodegenIntConstant(name=f"({self.name} - {other})", value=self.value - other, type=self.type, _dependencies={self})
    return CodegenIntConstant(name=f"({self.name} - {other.name})", value=self.value - other.value, type=self.type, _dependencies={self, other})

  def __rsub__(self, other: int) -> "CodegenIntConstant":
    return CodegenIntConstant(name=f"({other} - {self.name})", value=other - self.value, type=self.type, _dependencies={self})

  def __mul__(self, other: "CodegenIntConstant | int") -> "CodegenIntConstant":
    if isinstance(other, int):
      return CodegenIntConstant(name=f"({self.name} * {other})", value=self.value * other, type=self.type, _dependencies={self})
    return CodegenIntConstant(name=f"({self.name} * {other.name})", value=self.value * other.value, type=self.type, _dependencies={self, other})

  def __rmul__(self, other: int) -> "CodegenIntConstant":
    return self.__mul__(other)

  def __floordiv__(self, other: "CodegenIntConstant | int") -> "CodegenIntConstant":
    if isinstance(other, int):
      return CodegenIntConstant(name=f"({self.name} / {other})", value=self.value // other, type=self.type, _dependencies={self})
    return CodegenIntConstant(name=f"({self.name} / {other.name})", value=self.value // other.value, type=self.type, _dependencies={self, other})

dependencies()

Returns immediate dependencies (for toposort).

Source code in src/anvil/codegen/common.py
def dependencies(self) -> list[Self]:
  """Returns immediate dependencies (for toposort)."""
  return list(self._dependencies)

all_dependencies()

Returns all transitive dependencies including self.

Source code in src/anvil/codegen/common.py
def all_dependencies(self) -> set[Self]:
  """Returns all transitive dependencies including self."""
  result = {self}
  for dep in self._dependencies:
    result.update(dep.all_dependencies())
  return result

from_expr(name, expr) staticmethod

Create a named constant from an expression, preserving only base dependencies (no intermediate expressions).

Source code in src/anvil/codegen/common.py
@staticmethod
def from_expr(name: str, expr: "CodegenIntConstant") -> "CodegenIntConstant":
  """Create a named constant from an expression, preserving only base dependencies (no intermediate expressions)."""
  # Gather the whole dependency closure, then drop unnamed intermediates: anything whose
  # name contains "(" is an expression, and the expression itself is excluded too.
  closure = expr.all_dependencies()
  named_deps = {dep for dep in closure if "(" not in dep.name and dep != expr}
  return CodegenIntConstant(name=name, value=int(expr), type=expr.type, _dependencies=named_deps)

NumericalFunction dataclass

Bases: FunctionBase

A compiled numerical function: traces a Python/tensor function, schedules it, and JIT-compiles it to native code (C++/Metal/CUDA).

The compilation pipeline is: trace → build UOp graph → schedule → render kernels → JIT compile. Each stage is computed lazily via cached properties.

When called from Python, inputs and outputs are numpy arrays. For AOT code generation, use generate_module to emit C++ source files.

Parameters:

Name Type Description Default
name str

Unique function name (used in generated code identifiers).

required
fn Callable[..., Any]

The Python callable to trace. Must accept Tensors and return Tensor(s).

required
inputs tuple[Arg, ...]

Tuple of Arg specs describing each input's shape and dtype.

required
device str

Target device — "CPU", "METAL", or "CUDA".

'CPU'
Source code in src/anvil/codegen/numerical_function.py
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
@dataclass(frozen=True, init=False)
class NumericalFunction[I: tuple[Arg, ...], R: Tensor | tuple[Tensor, ...]](FunctionBase):
  """A compiled numerical function: traces a Python/tensor function, schedules it, and
  JIT-compiles it to native code (C++/Metal/CUDA).

  The compilation pipeline is: trace → build UOp graph → schedule → render kernels → JIT compile.
  Each stage is computed lazily via cached properties.

  When called from Python, inputs and outputs are numpy arrays. For AOT code generation,
  use `generate_module` to emit C++ source files.

  Args:
    name: Unique function name (used in generated code identifiers).
    fn: The Python callable to trace. Must accept Tensors and return Tensor(s).
    inputs: Tuple of Arg specs describing each input's shape and dtype.
    device: Target device — "CPU", "METAL", or "CUDA".
  """

  fn: Callable[..., R]
  inputs: I
  device: str = field(default="CPU", kw_only=True)

  # NOTE: overloads cover 1-4 inputs × single/dual output.
  # fmt: off
  @overload
  def __init__(self: "NumericalFunction[tuple[Arg], Tensor]", name: str, fn: Callable[[Tensor], Tensor], inputs: tuple[Arg], /, device: str = "CPU") -> None: ...
  @overload
  def __init__(self: "NumericalFunction[tuple[Arg, Arg], Tensor]", name: str, fn: Callable[[Tensor, Tensor], Tensor], inputs: tuple[Arg, Arg], /, device: str = "CPU") -> None: ...
  @overload
  def __init__(self: "NumericalFunction[tuple[Arg, Arg, Arg], Tensor]", name: str, fn: Callable[[Tensor, Tensor, Tensor], Tensor], inputs: tuple[Arg, Arg, Arg], /, device: str = "CPU") -> None: ...
  @overload
  def __init__(self: "NumericalFunction[tuple[Arg, Arg, Arg, Arg], Tensor]", name: str, fn: Callable[[Tensor, Tensor, Tensor, Tensor], Tensor], inputs: tuple[Arg, Arg, Arg, Arg], /, device: str = "CPU") -> None: ...
  @overload
  def __init__(self: "NumericalFunction[tuple[Arg], tuple[Tensor, Tensor]]", name: str, fn: Callable[[Tensor], tuple[Tensor, Tensor]], inputs: tuple[Arg], /, device: str = "CPU") -> None: ...
  @overload
  def __init__(self: "NumericalFunction[tuple[Arg, Arg], tuple[Tensor, Tensor]]", name: str, fn: Callable[[Tensor, Tensor], tuple[Tensor, Tensor]], inputs: tuple[Arg, Arg], /, device: str = "CPU") -> None: ...
  @overload
  def __init__(self: "NumericalFunction[tuple[Arg, Arg, Arg], tuple[Tensor, Tensor]]", name: str, fn: Callable[[Tensor, Tensor, Tensor], tuple[Tensor, Tensor]], inputs: tuple[Arg, Arg, Arg], /, device: str = "CPU") -> None: ...
  @overload
  def __init__(self: "NumericalFunction[tuple[Arg, Arg, Arg, Arg], tuple[Tensor, Tensor]]", name: str, fn: Callable[[Tensor, Tensor, Tensor, Tensor], tuple[Tensor, Tensor]], inputs: tuple[Arg, Arg, Arg, Arg], /, device: str = "CPU") -> None: ...
  # fmt: on
  def __init__(self, name: str, fn: Callable[..., Any], inputs: tuple[Arg, ...], device: str = "CPU") -> None:
    """Validate the target device and store the fields of this frozen dataclass.

    Raises:
      ValueError: If ``device`` is not in ``_SUPPORTED_DEVICES``, or if ``device`` is
        ``"METAL"`` and any input has a float64/double dtype.
    """
    if device not in _SUPPORTED_DEVICES:
      raise ValueError(f"unsupported device {device!r}, must be one of {_SUPPORTED_DEVICES}")
    if device == "METAL":
      rejected = (dtypes.float64, dtypes.double)
      for i, arg in enumerate(inputs):
        if arg.dtype in rejected:
          raise ValueError(f"Metal does not support float64: input {i} has dtype {arg.dtype}. Use Arg(..., dtype=dtypes.float32).")
    # the dataclass is frozen, so fields are written through object.__setattr__
    for attr, value in (("name", name), ("fn", fn), ("inputs", inputs), ("device", device)):
      object.__setattr__(self, attr, value)

  @property
  def is_sparse(self):
    """Always ``False`` for the dense ``NumericalFunction``."""
    return False

  @cached_property
  def _renderer(self):
    """Select the kernel renderer that matches ``self.device``.

    Raises:
      ValueError: If ``self.device`` is none of "CPU", "METAL" or "CUDA".
    """
    device = self.device
    if device == "CPU":
      # NOTE(review): returned as-is (not called), unlike the two GPU renderers below
      return CustomClangRenderer
    if device == "METAL":
      return _CustomMetalRenderer()
    if device == "CUDA":
      return _CustomCUDARenderer(arch=_detect_cuda_arch())
    raise ValueError(f"unsupported device {self.device!r}")

  @cached_property
  def cuda_arch(self) -> str:
    """CUDA architecture string (e.g. 'sm_89') for NVRTC compilation.

    Obtained from ``_detect_cuda_arch()``.
    """
    return _detect_cuda_arch()

  ##################################################################
  # inputs / outputs
  ##################################################################

  @cached_property
  def outputs(self) -> More[Arg]:
    """Output args, inferred from tracing: one ``Arg`` (shape, dtype) per traced output tensor."""
    traced = self.function_graph._trace_outputs
    return tuple(Arg(cast(Shape, out.shape), out.dtype) for out in traced)

  ##################################################################
  # function graph (parameterized UOp body with PARAMs)
  ##################################################################

  @cached_property
  def function_graph(self) -> FnGraph:
    """Trace ``fn`` on placeholder tensors and package the result as an ``FnGraph``.

    Placeholder inputs are created with ``Tensor.empty`` from each ``Arg``. Outputs whose
    base UOp is (by identity) the base UOp of an input get ``+ 0``; other outputs with more
    than one dimension that are not already ``CONTIGUOUS`` get ``.contiguous()``. The
    combined output UOp is simplified with ``symbolic_simple`` and the traced input UOps
    are replaced by ``param_like(i)`` placeholders.
    """
    # 1. Fresh tracing
    trace_inputs = tuple(Tensor.empty(make_more(arg.shape), dtype=arg.dtype) for arg in self.inputs)
    raw_outputs = make_more(cast(OneOrMore[Tensor], self.fn(*trace_inputs)))
    # Ensure each output has its own buffer for codegen:
    #  - multi-dim: .contiguous() ensures C-contiguous layout (avoids PERMUTE view chains)
    #  - input-aliased: +0 forces a new buffer (views of input can't be separate C++ params,
    #    and tinygrad's scheduler won't produce a copy kernel for CONTIGUOUS on views)
    input_base_ids: set[int] = set()
    for t in trace_inputs:
      # best effort: an input whose base cannot be resolved is simply not tracked for aliasing
      try:
        input_base_ids.add(id(t.uop.base))
      except (AttributeError, AssertionError):
        pass
    trace_outputs: tuple[Tensor, ...] = ()
    for t in raw_outputs:
      is_aliased = False
      # best effort: an output whose base cannot be resolved is treated as not aliased
      try:
        is_aliased = id(t.uop.base) in input_base_ids
      except (AttributeError, AssertionError):
        pass
      if is_aliased:
        trace_outputs += (t + 0,)  # force new buffer via trivial arithmetic
      elif t.uop.op is not Ops.CONTIGUOUS and len(t.shape) > 1:
        trace_outputs += (t.contiguous(),)  # force C-contiguous layout
      else:
        trace_outputs += (t,)

    # 2. Build output UOp
    # a single output is used directly; several outputs are bundled with UOp.maketuple
    uret = trace_outputs[0].uop if len(trace_outputs) == 1 else UOp.maketuple(*[o.uop for o in trace_outputs])

    # 3. Pre-simplify (fold 0*x, 0+x, etc.)
    uret = graph_rewrite(uret, symbolic_simple, name="function_graph pre-simplify")

    # 4. Substitute known inputs with PARAMs
    # the PARAM index is the position of the input UOp after de-duplication
    call_uops = tuple(dedup([t.uop for t in trace_inputs]))
    subs = {x: x.param_like(i) for i, x in enumerate(call_uops)}
    uret = uret.substitute(subs)

    return FnGraph(
      body=uret,
      call_uops=call_uops,
      n_outputs=len(trace_outputs),
      _trace_inputs=trace_inputs,
      _trace_outputs=trace_outputs,
    )

  ##################################################################
  # JIT compilation
  ##################################################################

  @cached_property
  def jit_source(self) -> str:
    """Source text returned by ``generate_jit_source`` for a module containing only this function."""
    # imported inside the property rather than at module level — presumably to avoid an import cycle; confirm
    from anvil.codegen.jit import generate_jit_source

    return generate_jit_source(self.name, [self])

  @cached_property
  def _jit_module(self):
    """Compile ``jit_source`` into a shared library and load it.

    The GPU backend passed to the compiler is the lower-cased device name, or ``None`` on CPU.
    """
    from anvil.codegen.jit import compile_to_shared_lib, load_jit_module

    gpu_backend = self.device.lower() if self.device != "CPU" else None
    lib_path = compile_to_shared_lib(self.jit_source, self.name, gpu_backend=gpu_backend)
    return load_jit_module(lib_path)

  @cached_property
  def _jit_ws(self) -> ctypes.c_void_p:
    """Workspace handle obtained from the JIT library's ``jit_<name>_init_ws`` entry point.

    A ``weakref.finalize`` callback is registered so that ``jit_<name>_free_ws`` is called
    with the same handle once this object is finalized.
    """
    init_fn = getattr(self._jit_module.lib, f"jit_{self.name}_init_ws")
    init_fn.argtypes = []
    # without an explicit restype ctypes would treat the returned pointer as a C int
    init_fn.restype = ctypes.c_void_p
    ws = init_fn()
    free_fn = getattr(self._jit_module.lib, f"jit_{self.name}_free_ws")
    free_fn.argtypes = [ctypes.c_void_p]
    free_fn.restype = None
    # the callback captures only free_fn and the raw handle, never `self`
    weakref.finalize(self, free_fn, ws)
    return ctypes.c_void_p(ws)

  @cached_property
  def _jit_call_fn(self):
    """The ``jit_<name>_call`` entry point with its ctypes signature configured.

    It takes one pointer per input buffer, one per output buffer, and one trailing
    pointer (the workspace handle at the call site), and returns nothing.
    """
    call = getattr(self._jit_module.lib, f"jit_{self.name}_call")
    pointer_count = len(self.input_bufs) + len(self.output_bufs) + 1
    call.argtypes = [_CPTR] * pointer_count
    call.restype = None
    return call

  ##################################################################
  # python calling interface
  ##################################################################

  @cached_property
  def _has_const_output(self) -> bool:
    """True if any output is a compile-time constant (no buffer after scheduling)."""
    scheduled_uops = self._schedule_and_output_uops[1]
    for uop in scheduled_uops:
      if uop.base.op is Ops.CONST:
        return True
    return False

  # fmt: off
  @overload
  def __call__(self: "NumericalFunction[tuple[Arg], Tensor]", a0: FloatArray, /) -> FloatArray: ...
  @overload
  def __call__(self: "NumericalFunction[tuple[Arg, Arg], Tensor]", a0: FloatArray, a1: FloatArray, /) -> FloatArray: ...
  @overload
  def __call__(self: "NumericalFunction[tuple[Arg, Arg, Arg], Tensor]", a0: FloatArray, a1: FloatArray, a2: FloatArray, /) -> FloatArray: ...
  @overload
  def __call__(self: "NumericalFunction[tuple[Arg, Arg, Arg, Arg], Tensor]", a0: FloatArray, a1: FloatArray, a2: FloatArray, a3: FloatArray, /) -> FloatArray: ...
  @overload
  def __call__(self: "NumericalFunction[tuple[Arg], tuple[Tensor, Tensor]]", a0: FloatArray, /) -> tuple[FloatArray, FloatArray]: ...
  @overload
  def __call__(self: "NumericalFunction[tuple[Arg, Arg], tuple[Tensor, Tensor]]", a0: FloatArray, a1: FloatArray, /) -> tuple[FloatArray, FloatArray]: ...
  @overload
  def __call__(self: "NumericalFunction[tuple[Arg, Arg, Arg], tuple[Tensor, Tensor]]", a0: FloatArray, a1: FloatArray, a2: FloatArray, /) -> tuple[FloatArray, FloatArray]: ...
  @overload
  def __call__(self: "NumericalFunction[tuple[Arg, Arg, Arg, Arg], tuple[Tensor, Tensor]]", a0: FloatArray, a1: FloatArray, a2: FloatArray, a3: FloatArray, /) -> tuple[FloatArray, FloatArray]: ...
  # fmt: on
  def __call__(self, *args: FloatArray) -> FloatArray | tuple[FloatArray, ...]:
    """Evaluate the function on numpy arrays (``Tensor`` arguments are also accepted).

    If any output is a compile-time constant, ``fn`` is evaluated directly on tensors and
    the results are converted with ``.numpy()``. Otherwise each input is converted to a
    C-contiguous array of the declared dtype, output arrays are allocated from
    ``self.outputs``, and the JIT entry point is called with the input pointers, the
    output pointers and the workspace handle.

    Returns:
      The single output array, or a tuple of arrays when there are several outputs.

    Raises:
      TypeError: If the number of arguments differs from the number of declared inputs.
      ValueError: If an argument's element count differs from its declared shape.
    """
    # zip() below would silently drop surplus arguments, so check the count up front
    if len(args) != len(self.inputs):
      raise TypeError(f"{self.name} expects {len(self.inputs)} input(s) but got {len(args)}")
    if self._has_const_output:
      tensor_args = tuple(a if isinstance(a, Tensor) else Tensor(a) for a in args)
      result = self.fn(*tensor_args)
      if isinstance(result, tuple):
        return tuple(r.numpy() for r in result)
      return result.numpy()  # ty: ignore[invalid-argument-type]
    arrays = []
    for i, (a, inp) in enumerate(zip(args, self.inputs)):
      arr = np.ascontiguousarray(a.numpy() if isinstance(a, Tensor) else a, dtype=_DTYPE_TO_NP[inp.dtype])
      # the native code receives a raw pointer, so a wrongly sized buffer must be rejected here
      expected = int(np.prod(make_more(inp.shape), dtype=np.int64))
      if arr.size != expected:
        raise ValueError(f"input {i} of {self.name} has {arr.size} element(s), expected {expected} for shape {inp.shape}")
      arrays.append(arr)
    outputs = tuple(np.empty(make_more(out.shape), dtype=_DTYPE_TO_NP[out.dtype]) for out in self.outputs)
    # `arrays` and `outputs` stay referenced until the call returns, keeping the pointers valid
    ptrs = [a.ctypes.data_as(_CPTR) for a in arrays] + [o.ctypes.data_as(_CPTR) for o in outputs] + [self._jit_ws]
    self._jit_call_fn(*ptrs)
    return outputs[0] if len(outputs) == 1 else outputs

  ##################################################################
  # scheduling
  ##################################################################

  @cached_property
  def _schedule_and_output_uops(self) -> tuple[list[ExecItem], tuple[UOp, ...]]:
    """Schedule the function graph and return (schedule, scheduled_output_uops).
    Uses separate tensors for scheduling so _trace_outputs remain unmodified
    (svmap/vmap need the original unscheduled function graph)."""
    fg = self.function_graph

    def _needs_contiguous(t: Tensor) -> bool:
      """Multi-dim outputs need contiguous to ensure C-order layout in the buffer."""
      if t.uop.op is Ops.CONTIGUOUS:
        return False
      return len(t.shape) > 1

    # Every trace output gets its own scheduling tensor, so the `t.uop = ...` rebinding
    # below never touches fg._trace_outputs. Multi-dim outputs that are not already
    # CONTIGUOUS get .contiguous() (C-contiguous buffer layout, avoids PERMUTE view chains).
    # Input-aliased outputs need no handling here: function_graph already gave them a
    # buffer of their own via `+ 0` while tracing.
    sched_outputs = tuple(t.contiguous() if _needs_contiguous(t) else Tensor(t.uop) for t in fg._trace_outputs)
    sink = UOp.sink(*[x.uop for x in sched_outputs])
    # pre-simplification: fold constants (0*x→0, 0+x→x, etc.) before the scheduler sees the graph.
    # this reduces graph size for JVP-heavy workloads where many tangents are CONST(0).
    old_sink = sink
    sink = graph_rewrite(sink, symbolic_simple, name="NumericalFunction pre-simplify")
    # point each scheduling tensor at its rewritten UOp
    for t, old_uop, new_uop in zip(sched_outputs, old_sink.src, sink.src):
      if old_uop is not new_uop:
        t.uop = new_uop
    with Context(TRACK_MATCH_STATS=0):
      call_sink, buffer_map = transform_to_call(sink)
      call_sink = call_sink.replace(arg=CallInfo(name=self.name))
    new_sink = sink.substitute(buffer_map, name="Apply CUSTOM Buffer Map")
    # same rebinding again, now onto the UOps produced by applying the buffer map
    for t, old_uop, new_uop in zip(sched_outputs, sink.src, new_sink.src):
      if old_uop is not new_uop:
        t.uop = new_uop
    scheduled_uops = tuple(t.uop for t in sched_outputs)
    schedule, var_vals = complete_create_schedule_with_vars(call_sink)
    # the schedule is expected to carry no variable values
    assert len(var_vals) == 0
    return schedule, scheduled_uops

  @cached_property
  def schedule(self) -> list[ExecItem]:
    """The schedule half of ``_schedule_and_output_uops``."""
    return self._schedule_and_output_uops[0]

  ##################################################################
  # kernel rendering
  ##################################################################

  @cached_property
  def rendered_kernels(self) -> list[RenderedKernel]:
    """Render every schedule item into a ``RenderedKernel``.

    ``SINK`` items go through ``render_kernel`` with this function's renderer; ``COPY``
    items become a source-less kernel named "copy" whose output is ``bufs[0]`` and whose
    input is ``bufs[1]``.

    Raises:
      NotImplementedError: If a schedule item's AST op is neither ``SINK`` nor ``COPY``.
    """
    kernels: list[RenderedKernel] = []
    for item in self.schedule:
      bufs = cast(list[Buffer], item.bufs)
      op = item.ast.op
      if op == Ops.SINK:
        kernels.append(render_kernel(item.ast, bufs, self._renderer))
      elif op == Ops.COPY:
        dst, src = bufs[0], bufs[1]
        kernels.append(RenderedKernel(name="copy", src="", ast=item.ast, globals=(dst, src), ins=(src,), outs=(dst,)))
      else:
        raise NotImplementedError(f"Unsupported operation: {item.ast.op}")
    # sanity check: a buffer written by a COPY kernel must not be written by any other kernel
    for copy_rk in kernels:
      if copy_rk.op is not Ops.COPY:
        continue
      target = copy_rk.outs[0]
      for other in kernels:
        if other is copy_rk:
          continue
        assert target not in other.outs, f"kernel {other.name} overrides {target} created by {copy_rk.name}"
    return kernels

  ##################################################################
  # buffers
  ##################################################################

  @cached_property
  def constant_bufs(self) -> More[Buffer]:
    """De-duplicated sources of COPY kernels that are neither input nor output buffers."""
    # constant buffers are COPY sources that come from non-CPU devices (PYTHON, NPY)
    # they represent data known at codegen time (e.g. matrix constants, index arrays)
    sources = []
    for rk in self.rendered_kernels:
      if rk.op is not Ops.COPY:
        continue
      src = rk.ins[0]
      if src in self.input_bufs or src in self.output_bufs:
        continue
      sources.append(src)
    return tuple(dedup(sources))

  @cached_property
  def input_bufs(self) -> More[Buffer]:
    """The base buffer of every traced input tensor, in input order."""
    traced_inputs = self.function_graph._trace_inputs
    return tuple(cast(Buffer, tensor.uop.base.buffer) for tensor in traced_inputs)

  @cached_property
  def output_bufs(self) -> More[Buffer]:
    """The base buffer of every scheduled output UOp, in output order."""
    # read the scheduled UOps, not the trace outputs: the latter stay in their pre-schedule state
    scheduled_uops = self._schedule_and_output_uops[1]
    return tuple(cast(Buffer, uop.base.buffer) for uop in scheduled_uops)

  @cached_property
  def arg_bufs(self) -> More[Buffer]:
    """Input buffers followed by output buffers."""
    return self.input_bufs + self.output_bufs

  @cached_property
  def intermediate_bufs(self) -> More[Buffer]:
    """Buffers used by the rendered kernels that are neither arguments nor constants.

    De-duplicated, in order of first appearance across ``rk.globals``.
    """
    candidates = [buf for rk in self.rendered_kernels for buf in rk.globals]
    if not candidates:
      return ()
    # built once instead of once per candidate buffer
    excluded = self.arg_bufs + self.constant_bufs
    return tuple(dedup(buf for buf in candidates if buf not in excluded))

  ##################################################################
  # final rendering
  ##################################################################

  @cached_property
  def buf_shapes(self) -> dict[Buffer, OneOrMore[int]]:
    """Map every argument buffer to the shape of the ``Arg`` at the same position.

    Input buffers are filled in first, then output buffers; a buffer present in both
    groups ends up with the output shape.
    """
    shapes: dict[Buffer, OneOrMore[int]] = {}
    for i, buf in enumerate(self.input_bufs):
      shapes[buf] = self.inputs[i].shape
    for i, buf in enumerate(self.output_bufs):
      shapes[buf] = self.outputs[i].shape
    return shapes

  @cached_property
  def buf_names(self) -> dict[Buffer, str]:
    """Returns a dict with unique names for each buffer (input/output/intermediate/constant).

    Names are ``in<i>``, ``out<i>``, ``intermediate<i>`` and ``constant<i>``, numbered by
    position within each group. Groups are applied in that order, so a buffer that
    belongs to several groups keeps the name from the last one.
    """
    groups = (
      ("in", self.input_bufs),
      ("out", self.output_bufs),
      ("intermediate", self.intermediate_bufs),
      ("constant", self.constant_bufs),
    )
    names: dict[Buffer, str] = {}
    for prefix, bufs in groups:
      for i, buf in enumerate(bufs):
        names[buf] = f"{prefix}{i}"
    return names

  @cached_property
  def buf_type_names(self) -> dict[Buffer, str]:
    """Type name of each argument buffer: its upper-cased buffer name plus ``_t`` (e.g. ``IN0_t``)."""
    return {buf: f"{self.buf_names[buf].upper()}_t" for buf in self.arg_bufs}

  @cached_property
  def buf_type_declarations(self) -> dict[Buffer, str]:
    """``typedef Buffer<dtype[,shape]> NAME;`` declaration for each argument buffer.

    The comma and shape are left out when the rendered shape string is empty.
    """
    rendered_shapes = {}
    for buf in self.arg_bufs:
      rendered_shapes[buf] = render_shape(self.buf_shapes[buf])
    declarations: dict[Buffer, str] = {}
    for buf in self.arg_bufs:
      dtype_str = CustomClangRenderer.render_dtype(buf.dtype)
      shape_str = rendered_shapes[buf]
      separator = "," if len(shape_str) > 0 else ""
      declarations[buf] = f"typedef Buffer<{dtype_str}{separator + shape_str}> {self.buf_type_names[buf]};"
    return declarations

  @cached_property
  def constant_buf_declarations(self) -> dict[Buffer, str]:
    """``static constexpr`` array definition, with its literal values, for each constant buffer."""
    declarations: dict[Buffer, str] = {}
    for buf in self.constant_bufs:
      dtype_str = CustomClangRenderer.render_dtype(buf.dtype)
      name = self.buf_names[buf]
      values = array_to_values_string(buf.numpy())
      declarations[buf] = f"static constexpr {dtype_str} {name}[{buf.size}] = {{{values}}};"
    return declarations

  @cached_property
  def intermediate_buf_declarations(self) -> dict[Buffer, str]:
    """C++ statement that views each intermediate buffer as a typed slice of the workspace.

    Each statement declares ``auto <name> = Buffer<dtype, size>{...}`` pointing at the
    workspace data plus the buffer's byte offset.
    """
    declarations: dict[Buffer, str] = {}
    for buf in self.intermediate_bufs:
      name = self.buf_names[buf]
      dtype_str = CustomClangRenderer.render_dtype(buf.dtype)
      base = f"{self.ws_buf_name}.data + {self.intermediate_buf_offsets[buf]}"
      pointer = f"static_cast<{dtype_str}*>(static_cast<void*>({base}))"
      declarations[buf] = f"auto {name} = Buffer<{dtype_str}, {buf.size}>{{{pointer}}};"
    return declarations

  @cached_property
  def intermediate_buf_offsets(self) -> dict[Buffer, int]:
    """Byte offset of every intermediate buffer inside the global workspace vector.

    Buffers are laid out back to back in ``intermediate_bufs`` order. Each buffer's
    footprint is rounded up to a multiple of 16 bytes, so every offset is a multiple of
    16 (NEON vector size; per the original note, alignment is guaranteed by Buffer::alloc).
    """
    alignment = 16  # NEON vector size
    layout: dict[Buffer, int] = {}
    cursor = 0
    for buf in self.intermediate_bufs:
      layout[buf] = cursor
      # ceil-divide then multiply: rounds nbytes up to the next multiple of `alignment`
      cursor += -(-buf.nbytes // alignment) * alignment
    return layout

  @cached_property
  def ws_size(self) -> int:
    """Returns the size in bytes of the global workspace vector.

    This is the end of the last intermediate buffer (its offset plus its unpadded size),
    or 0 when there are no intermediate buffers.
    """
    bufs = self.intermediate_bufs
    if len(bufs) == 0:
      return 0
    tail = bufs[-1]
    return self.intermediate_buf_offsets[tail] + tail.nbytes

  @property
  def ws_buf_name(self):
    """Name used for the workspace buffer in generated code."""
    return "ws"

  @property
  def ws_type_name(self) -> str:
    """Name of the workspace buffer typedef."""
    return "WS_t"

  @cached_property
  def ws_type_declaration(self) -> str:
    """``typedef`` of the workspace as a ``Buffer`` of ``ws_size`` elements of the rendered ``dtypes.char`` type."""
    return f"typedef Buffer<{CustomClangRenderer.render_dtype(dtypes.char)},{self.ws_size}> {self.ws_type_name};"

  @cached_property
  def copy_kernel_calls(self) -> list[str]:
    """One ``std::memcpy`` statement per COPY kernel, in kernel order.

    The destination is always ``<name>.data``. A constant-buffer source is referenced by
    its bare name; any other source goes through ``.data``. The byte count is the source
    buffer's ``nbytes``.
    """
    calls: list[str] = []
    for rk in self.rendered_kernels:
      if rk.op is not Ops.COPY:
        continue
      src, dst = rk.ins[0], rk.outs[0]
      src_name = self.buf_names[src]
      src_expr = src_name if src in self.constant_bufs else src_name + ".data"
      calls.append(f"std::memcpy({self.buf_names[dst]}.data, {src_expr}, {src.nbytes});")
    return calls

  @cached_property
  def kernel_calls(self) -> list[str]:
    """One call statement per non-COPY kernel, passing ``<name>.data`` for each of its global buffers."""
    # TODO: this is not resilient to having more than two types of RenderedKernels. make this into a proper method
    calls: list[str] = []
    for rk in self.rendered_kernels:
      if rk.op is Ops.COPY:
        continue
      arguments = ", ".join(self.buf_names[buf] + ".data" for buf in rk.globals)
      calls.append(f"{rk.function_name}({arguments});")
    return calls

  ##################################################################
  # GPU-specific properties
  ##################################################################

  @cached_property
  def gpu_kernel_sources(self) -> dict[str, str]:
    """Maps kernel function_name to its source code string (for GPU embedding)."""
    sources: dict[str, str] = {}
    for rk in self.rendered_kernels:
      # only SINK kernels are included
      if rk.op is Ops.SINK:
        sources[rk.function_name] = rk.src
    return sources

  @cached_property
  def gpu_buf_nbytes(self) -> dict[Buffer, int]:
    """Maps each buffer to its byte size (for GPU allocation).

    Keys follow a fixed order (inputs, outputs, intermediates, constants; first
    occurrence wins) rather than set-iteration order, so anything that iterates this
    mapping sees the same order on every run.
    """
    ordered_bufs = self.input_bufs + self.output_bufs + self.intermediate_bufs + self.constant_bufs
    # dict keys are unique, so buffers that appear more than once collapse to one entry
    return {buf: buf.nbytes for buf in ordered_bufs}

  @cached_property
  def gpu_kernel_buf_indices(self) -> dict[str, list[tuple[int, Buffer]]]:
    """For each SINK kernel, returns ordered (arg_index, Buffer) for GPU dispatch."""
    return {rk.function_name: list(enumerate(rk.globals)) for rk in self.rendered_kernels if rk.op is Ops.SINK}

  ##################################################################
  # abstract properties from FunctionBase to implement
  ##################################################################

  @cached_property
  @override
  def codegen_constants(self) -> set[CodegenIntConstant]:
    """This function contributes no codegen constants (always an empty set)."""
    return set()

  @cached_property
  @override
  def header_includes(self) -> set[str]:
    """This function contributes no header includes (always an empty set)."""
    return set()

  @cached_property
  @override
  def header_code(self) -> str:
    """Render the header template for this function.

    CPU uses ``numerical_function_header.j2``; other devices use
    ``numerical_function_gpu_<device>.j2``.

    Raises:
      ValueError: On METAL, if any captured constant buffer has a float64/double dtype.
    """
    if self.device == "CPU":
      return JINJA_ENV.get_template("numerical_function_header.j2").render(fn=self)
    if self.device == "METAL":
      rejected = (dtypes.float64, dtypes.double)
      for buf in self.constant_bufs:
        if buf.dtype in rejected:
          raise ValueError(
            f"Metal does not support float64: captured constant buffer has dtype {buf.dtype}. "
            "Create constants with dtype=dtypes.float32 (e.g. Tensor(..., dtype=dtypes.float32))."
          )
    template_name = f"numerical_function_gpu_{self.device.lower()}.j2"
    return JINJA_ENV.get_template(template_name).render(fn=self)

  @cached_property
  @override
  def source_code(self) -> str:
    """Render ``numerical_function_source.j2`` on CPU; return an empty string otherwise."""
    if self.device == "CPU":
      return JINJA_ENV.get_template("numerical_function_source.j2").render(fn=self)
    # GPU: everything in header_code (single template)
    return ""

  @cached_property
  @override
  def source_includes(self) -> set[str]:
    """Include lines for the generated source file: ``<cstdlib>`` and ``<cstring>``."""
    return {"#include <cstdlib>", "#include <cstring>"}

cuda_arch cached property

CUDA architecture string (e.g. 'sm_89') for NVRTC compilation.

outputs cached property

Output args, inferred from tracing.

buf_names cached property

Returns a dict with unique names for each buffer (input/output/intermediate/constant).

intermediate_buf_offsets cached property

Returns a dict with the offset in bytes of each intermediate buffer in a global workspace vector. Ensures proper alignment of each buffer to 16 bytes (NEON vector size; alignment guaranteed by Buffer::alloc).

ws_size cached property

Returns the size in bytes of the global workspace vector.

gpu_kernel_sources cached property

Maps kernel function_name to its source code string (for GPU embedding).

gpu_buf_nbytes cached property

Maps each buffer to its byte size (for GPU allocation).

gpu_kernel_buf_indices cached property

For each SINK kernel, returns ordered (arg_index, Buffer) for GPU dispatch.

SparseNumericalFunction dataclass

Bases: NumericalFunction[I, Tensor]

A NumericalFunction whose output is a sparse matrix in CSC format.

The underlying function computes only the non-zero values (a 1D dense array of length nnz). The sparsity structure (indices, indptr) is fixed at construction time.

When called from Python, returns a scipy.sparse.csc_array assembled from the computed non-zero values and the stored sparsity pattern.

Parameters:

Name Type Description Default
shape Shape

(rows, cols) shape of the full sparse matrix.

required
nnz int

Number of non-zero entries.

required
indices Tensor

CSC row indices (length nnz).

required
indptr Tensor

CSC column pointers (length cols + 1).

required
Source code in src/anvil/codegen/numerical_function.py
@dataclass(frozen=True)
class SparseNumericalFunction[I: tuple[Arg, ...]](NumericalFunction[I, Tensor]):
  """A NumericalFunction whose output is a sparse matrix in CSC format.

  The underlying function computes only the non-zero values (a 1D dense array of length `nnz`).
  The sparsity structure (`indices`, `indptr`) is fixed at construction time.

  When called from Python, returns a `scipy.sparse.csc_array` assembled from the computed
  non-zero values and the stored sparsity pattern.

  Args:
    shape: (rows, cols) shape of the full sparse matrix.
    nnz: Number of non-zero entries.
    indices: CSC row indices (length `nnz`).
    indptr: CSC column pointers (length `cols + 1`).
  """

  shape: Shape
  nnz: int
  indices: Tensor
  indptr: Tensor

  @property
  def is_sparse(self):
    return True

  @cached_property
  def _indices_np(self) -> np.ndarray:
    with silent_realize():
      return self.indices.numpy().astype(np.int64)

  @cached_property
  def _indptr_np(self) -> np.ndarray:
    with silent_realize():
      return self.indptr.numpy().astype(np.int64)

  def __call__(self, *args: FloatArray) -> sp.csc_array:  # type: ignore[override]
    """Evaluate the function and return the result as a ``scipy.sparse.csc_array``."""
    data = super().__call__(*args)
    return sp.csc_array((data, self._indices_np, self._indptr_np), shape=self.shape)

  @cached_property
  def inner_indices_declaration(self) -> str:
    with silent_realize():
      return f"static constexpr int innerIndices[{int(self.indices.shape[0])}] = {{{tensor_to_values_string(self.indices)}}};"

  @cached_property
  def outer_starts_declaration(self) -> str:
    with silent_realize():
      return f"static constexpr int outerStarts[{int(self.indptr.shape[0])}] = {{{tensor_to_values_string(self.indptr)}}};"

  @cached_property
  @override
  def header_code(self) -> str:
    return JINJA_ENV.get_template("numerical_function_header.j2").render(fn=self)

  @cached_property
  @override
  def source_code(self) -> str:
    return JINJA_ENV.get_template("numerical_function_source.j2").render(fn=self)

numerical_function(name, inputs)

Decorator that creates a NumericalFunction from a plain Python function.

Parameters:

Name Type Description Default
name str

Function name for generated code.

required
inputs More[OneOrMore[int]]

Tuple of shapes, one per input (dtype defaults to float64).

required

Example::

@numerical_function("my_fn", ((3,), (2,)))
def my_fn(x: Tensor, y: Tensor) -> Tensor:
    return x + y.pad(((0, 1),))
Source code in src/anvil/codegen/numerical_function.py
def numerical_function(name: str, inputs: More[OneOrMore[int]]) -> Callable:
  """Build a decorator that turns a plain Python function into a NumericalFunction.

  Every entry of `inputs` is wrapped in an `Arg`, so each input uses the default
  dtype (float64).

  Args:
    name: Function name for generated code.
    inputs: Tuple of shapes, one per input (dtype defaults to float64).

  Example::

      @numerical_function("my_fn", ((3,), (2,)))
      def my_fn(x: Tensor, y: Tensor) -> Tensor:
          return x + y.pad(((0, 1),))
  """

  def wrap(fn: Callable) -> NumericalFunction:
    # The Arg specs are built when the decorator is applied, not when it is created.
    arg_specs = tuple(Arg(input_shape) for input_shape in inputs)
    return NumericalFunction(name, fn, arg_specs)  # ty: ignore[no-matching-overload]

  return wrap

sparse_numerical_function(name, inputs, shape, nnz, indices, indptr)

Decorator that creates a SparseNumericalFunction with an explicit CSC sparsity pattern.

The decorated function must return a 1D Tensor of length nnz containing the non-zero values.

Source code in src/anvil/codegen/numerical_function.py
def sparse_numerical_function(
  name: str, inputs: OneOrMore[Arg], shape: Shape, nnz: int, indices: Tensor, indptr: Tensor
) -> Callable[[Callable], SparseNumericalFunction[tuple[Arg, ...]]]:
  """Build a decorator producing a SparseNumericalFunction with a fixed CSC sparsity pattern.

  The function being decorated must return a 1D Tensor of length `nnz` holding the
  non-zero values. `inputs` is normalized with `make_more` before being handed to
  the SparseNumericalFunction constructor.
  """

  def wrap(fn: Callable) -> SparseNumericalFunction[tuple[Arg, ...]]:
    return SparseNumericalFunction(name, fn, make_more(inputs), shape, nnz, indices, indptr)

  return wrap

generate_module(module_name, fns, constants=None, verbosity=0)

Generate a C++ module (<module_name>.hpp + <module_name>.cpp) from a list of functions.

Functions are topologically sorted so dependencies are emitted first. All codegen constants are collected, deduplicated, and declared before any function code.

Parameters:

Name Type Description Default
module_name str

Base name for the output files.

required
fns list[FunctionBase]

Functions to include (NumericalFunction, SQPFunction, etc.).

required
constants set[CodegenIntConstant] | None

Extra codegen constants to declare beyond those from fns.

None
verbosity

If >= 1, print file names as they are generated.

0
Source code in src/anvil/codegen/common.py
def generate_module(module_name: str, fns: list[FunctionBase], constants: set[CodegenIntConstant] | None = None, verbosity: int = 0) -> None:
  """Generate a C++ module (``<module_name>.hpp`` + ``<module_name>.cpp``) from a list of functions.

  Functions are topologically sorted so dependencies are emitted first. All codegen
  constants are collected, deduplicated, and declared before any function code.

  The `fns` list passed by the caller is not modified: dependencies are gathered
  into an internal copy, so the same list can safely be reused for another call.

  Args:
    module_name: Base name for the output files.
    fns: Functions to include (NumericalFunction, SQPFunction, etc.).
    constants: Extra codegen constants to declare beyond those from `fns`.
    verbosity: If >= 1, print file names as they are generated.

  Raises:
    AssertionError: If `module_name` is empty, `fns` is empty, or two entries of
      `fns` share the same name.
  """
  # perform some checks
  assert module_name != "", "Module name cannot be empty!"
  assert len(fns) != 0, "Function list cannot be empty!"
  fns_names = [fn.name for fn in fns]
  assert len(fns_names) == len(set(fns_names)), "Function names must be unique!"

  # Add the dependencies of every function (e.g. eq_constraints for MPCFunction).
  # This works on a copy so the caller's list is left untouched. The index-based
  # walk also visits the entries appended along the way, which picks up
  # dependencies of dependencies in the same order as before.
  all_fns = list(fns)
  idx = 0
  while idx < len(all_fns):
    all_fns.extend(all_fns[idx].dependencies())
    idx += 1

  # topologically sort the functions so that dependencies are emitted first
  sorted_fns = toposort(all_fns)

  # collect all codegen constants from functions and explicit constants
  all_constants: set[CodegenIntConstant] = set()
  for fn in sorted_fns:
    for const in fn.codegen_constants:
      all_constants.update(const.all_dependencies())
  if constants is not None:
    for const in constants:
      all_constants.update(const.all_dependencies())

  # Sort constants in topological order (dependencies before dependents).
  # The set is ordered by name first so the input to toposort does not depend on
  # hash-based set iteration order.
  # NOTE(review): assumes toposort is deterministic for a given input order.
  sorted_constants = toposort(sorted(all_constants, key=lambda const: const.name), reverse=False)

  # Include lines are merged across functions and sorted so that the generated
  # files are identical from one run to the next.
  header_includes = sorted(set().union(*(fn.header_includes for fn in sorted_fns)))
  source_includes = sorted(set().union(*(fn.source_includes for fn in sorted_fns)))

  # generate header file
  with open((filename := module_name + ".hpp"), "w") as f:
    if verbosity >= 1:
      print("Generating header file: " + filename)
    f.write(
      JINJA_ENV.get_template("module_header.j2").render(
        module_name=module_name,
        includes=header_includes,
        codegen_constants=sorted_constants,
        fns_code=[fn.header_code for fn in sorted_fns],
      )
    )

  # generate source file
  with open((filename := module_name + ".cpp"), "w") as f:
    if verbosity >= 1:
      print("Generating source file: " + filename)
    f.write(
      JINJA_ENV.get_template("module_source.j2").render(
        module_name=module_name,
        includes=source_includes,
        fns_code=[fn.source_code for fn in sorted_fns],
      )
    )