Multistage formulation¶
Nonlinear program¶
The MultistageFn / MultistageProblem classes implement a uniform-stage-dimension multistage NLP:

\[
\begin{aligned}
\min_{z_0, \ldots, z_N} \quad & \tfrac{1}{2}\|r_0(z_0)\|^2 + \sum_{i=1}^{N-1} \tfrac{1}{2}\|r(z_i)\|^2 + \tfrac{1}{2}\|r_N(z_N)\|^2 + \sum_{i=0}^{N-1} \tfrac{1}{2}\|r_d(z_i, z_{i+1})\|^2 \\
\text{s.t.} \quad & a_0(z_0) = 0, \qquad a_d(z_i, z_{i+1}) = 0, \qquad a_N(z_N) = 0, \\
& g_0(z_0) \le 0, \qquad g(z_i) \le 0, \qquad g_d(z_i, z_{i+1}) \le 0, \qquad g_N(z_N) \le 0,
\end{aligned}
\]
where all stage variables have the same dimension \(z_i \in \mathbb{R}^{n_z}\) and all functions take a per-stage runtime parameter \(p_i \in \mathbb{R}^{n_{p_s}}\) as trailing input (omitted above for clarity). All derivatives are taken with respect to stage variables only; \(p_i\) is fixed data.
Per-stage parameter dispatch¶
The solver sees a flat parameter vector \(p \in \mathbb{R}^{n_{p_s} \cdot (N+1)}\). Internally, it is reshaped to \((N+1, n_{p_s})\) and \(p_i\) is dispatched to stage \(i\)'s functions:
- Stage functions \(f(z_i, p_i)\) receive \(p_i\).
- Interstage functions \(f(z_i, z_{i+1}, p_i)\) receive \(p_i\) (source stage parameter).
- The per-stage parameter dimension \(n_{p_s}\) is inferred from the component functions' last input shape (the stage_param_dim property).
- The total parameter dimension is \(n_p = n_{p_s} \cdot (N+1)\) (the nparam property).
- Users with shared (global) parameters should tile them \(N+1\) times.
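A minimal NumPy sketch of this dispatch, assuming \(N = 3\), \(n_{p_s} = 2\), and hypothetical stage/interstage functions:

```python
import numpy as np

N, n_ps = 3, 2
p_flat = np.arange((N + 1) * n_ps, dtype=float)  # solver-facing flat vector

# Reshape to (N+1, n_ps): row i is the per-stage parameter p_i.
p = p_flat.reshape(N + 1, n_ps)

def stage_fn(z, p_i):
    # hypothetical stage function f(z_i, p_i)
    return z + p_i

def interstage_fn(z, z_next, p_i):
    # hypothetical interstage function f(z_i, z_{i+1}, p_i):
    # receives the *source* stage's parameter p_i
    return z_next - z + p_i

z = np.ones((N + 1, n_ps))
stage_vals = [stage_fn(z[i], p[i]) for i in range(1, N)]
inter_vals = [interstage_fn(z[i], z[i + 1], p[i]) for i in range(N)]
```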
Assumptions:
- uniform stage dimension across the whole horizon
- uniform per-stage parameter dimension across all component functions
- stage/interstage functions are the same at all applicable stages (initial and terminal stages may use different functions)
- all callbacks accept a trailing per-stage p input: stage functions take (z, p), interstage functions take (z, znext, p)
- costs have the nonlinear least-squares form \(l(z, p) = \tfrac{1}{2}\|r(z, p)\|^2\)
- equality constraint RHS is always 0
Mapping to MultistageFn¶
Each of the three NLP components (equality constraints, inequality constraints, residuals) is represented by a MultistageFn with four optional sub-functions:
| MultistageFn field | Equality | Inequality | Residual |
|---|---|---|---|
| initial_fn | \(a_0(z_0, p_0)\) | \(g_0(z_0, p_0)\) | \(r_0(z_0, p_0)\) |
| stage_fn | — | \(g(z_i, p_i)\) | \(r(z_i, p_i)\) |
| terminal_fn | \(a_N(z_N, p_N)\) | \(g_N(z_N, p_N)\) | \(r_N(z_N, p_N)\) |
| interstage_fn | \(a_d(z_i, z_{i+1}, p_i)\) | \(g_d(z_i, z_{i+1}, p_i)\) | \(r_d(z_i, z_{i+1}, p_i)\) |
Component functions define their signature with the per-stage parameter dimension \(n_{p_s}\). The param_dim field on _MultistageFn stores \(n_{p_s}\); the total \(n_p = n_{p_s} \cdot (N+1)\) is available via nparam.
MultistageProblem composes three MultistageFn instances and exposes the SQP interface.
Global row ordering¶
Each MultistageFn evaluates to a vector with row ordering:
```
initial(z_0)
interstage(z_0,z_1), stage(z_1)
interstage(z_1,z_2), stage(z_2)
...
interstage(z_{N-2},z_{N-1}), stage(z_{N-1})
interstage(z_{N-1},z_N)
terminal(z_N)
```
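A NumPy sketch of this interleaving, with \(N = 3\) and hypothetical scalar component functions:

```python
import numpy as np

N = 3

# Hypothetical scalar-valued component functions (parameters omitted).
def initial(z0):       return np.array([z0])
def stage(z):          return np.array([10.0 * z])
def interstage(z, zn): return np.array([zn - z])
def terminal(zN):      return np.array([100.0 * zN])

Z = np.array([0.0, 1.0, 2.0, 3.0])  # stage variables z_0..z_N

rows = [initial(Z[0])]
for i in range(N):
    rows.append(interstage(Z[i], Z[i + 1]))
    if i + 1 < N:                    # stage term only for z_1..z_{N-1}
        rows.append(stage(Z[i + 1]))
rows.append(terminal(Z[N]))
out = np.concatenate(rows)
```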
Multistage QP subproblem (PIQP interface)¶
At each SQP iteration, the NLP is linearized into a block-tridiagonal QP of the form

\[
\min_{\Delta Z}\ \tfrac{1}{2}\,\Delta Z^\top P\,\Delta Z + c^\top \Delta Z
\quad \text{s.t.} \quad A\,\Delta Z = b, \qquad G\,\Delta Z \le h,
\]

which PIQP's multistage backend accepts directly. Here \(Z = \begin{bmatrix} z_0^\top & z_1^\top & \cdots & z_N^\top \end{bmatrix}^\top\) collects the stage variables at the current iterate, and the global matrices are assembled block-wise as follows.
Equality constraint blocks¶
From the linearized equality MultistageFn (shown for \(N=3\)):
Reading off the per-stage blocks:
Inequality constraint blocks¶
From the linearized inequality MultistageFn (shown for \(N=3\)):
Reading off the per-stage blocks:
Residual Jacobian blocks¶
The residual MultistageFn Jacobian \(R = DR(Z)\) has the same block structure (shown for \(N=3\)):
Cost gradient¶
The cost gradient \(c = DR(Z)^\top R(Z)\) is computed per-block:
Gauss-Newton Hessian¶
The Gauss-Newton Hessian \(P = DR(Z)^\top DR(Z)\) is block-tridiagonal. The implementation computes it per-block rather than forming the full Jacobian.
Diagonal blocks:
Off-diagonal blocks:
Each diagonal block is computed with spsyrk (sparse \(J^\top J\)) on the stacked per-argnum Jacobians. Each off-diagonal block is computed with spmm_tn (sparse \(A^\top B\)). Both are batched over the horizon with svmap.
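A dense NumPy analogue of this per-block computation (interstage residuals only, hypothetical random Jacobian blocks), cross-checked against the full \(J^\top J\):

```python
import numpy as np

rng = np.random.default_rng(0)
N, nz, nr = 3, 2, 2  # horizon, stage dim, residual dim (assumed sizes)

# Hypothetical Jacobian blocks of the interstage residual r_d(z_i, z_{i+1}):
# A[i] = dr_d/dz_i, B[i] = dr_d/dz_{i+1}.
A = rng.standard_normal((N, nr, nz))
B = rng.standard_normal((N, nr, nz))

def diag_block(i):
    # Dense analogue of spsyrk on the stacked per-argnum Jacobians.
    J_stack = []
    if i > 0:
        J_stack.append(B[i - 1])
    if i < N:
        J_stack.append(A[i])
    J = np.vstack(J_stack)
    return J.T @ J

def offdiag_block(i):
    # Dense analogue of spmm_tn (A^T B).
    return A[i].T @ B[i]

# Cross-check against the full Jacobian of the stacked interstage residuals.
J_full = np.zeros((N * nr, (N + 1) * nz))
for i in range(N):
    J_full[i * nr:(i + 1) * nr, i * nz:(i + 1) * nz] = A[i]
    J_full[i * nr:(i + 1) * nr, (i + 1) * nz:(i + 2) * nz] = B[i]
P_full = J_full.T @ J_full
```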
General scalar cost (_MultistageCostFn)¶
For general (non-least-squares) cost functions the NLP objective is

\[
L(Z) = l_0(z_0, p_0) + \sum_{i=1}^{N-1} l(z_i, p_i) + l_N(z_N, p_N) + \sum_{i=0}^{N-1} l_d(z_i, z_{i+1}, p_i),
\]
where \(l_0, l, l_N : \mathbb{R}^{n_z} \times \mathbb{R}^{n_{p_s}} \to \mathbb{R}\) and \(l_d : \mathbb{R}^{n_z} \times \mathbb{R}^{n_z} \times \mathbb{R}^{n_{p_s}} \to \mathbb{R}\) are arbitrary smooth scalars (all taking per-stage \(p_i\) as trailing input). This is implemented by _MultistageCostFn.
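A minimal evaluation sketch of \(L(Z)\), assuming \(N = 3\), \(n_z = n_{p_s} = 1\), and hypothetical quadratic cost terms:

```python
import numpy as np

N = 3
Z = np.array([0.0, 1.0, 2.0, 3.0])     # stage variables z_0..z_N (n_z = 1)
p = np.array([1.0, 1.0, 1.0, 1.0])     # per-stage parameters p_0..p_N

l0 = lambda z, p: p * z**2             # initial cost l_0(z_0, p_0)
ls = lambda z, p: p * z**2             # stage cost l(z_i, p_i)
lN = lambda z, p: p * z**2             # terminal cost l_N(z_N, p_N)
ld = lambda z, zn, p: p * (zn - z)**2  # interstage cost l_d(z_i, z_{i+1}, p_i)

# L(Z) = l_0 + sum_{i=1}^{N-1} l + l_N + sum_{i=0}^{N-1} l_d
L = (l0(Z[0], p[0])
     + sum(ls(Z[i], p[i]) for i in range(1, N))
     + lN(Z[N], p[N])
     + sum(ld(Z[i], Z[i + 1], p[i]) for i in range(N)))
```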
Cost gradient¶
The gradient \(\nabla L(Z)\) decomposes block-wise. Since interstage functions use the source stage's parameter (\(l_d(z_i, z_{i+1}, p_i)\) uses \(p_i\)), the per-block gradient functions receive both the current and previous stage parameters where applicable:
Note that \(\bar{g}_0\) only needs \(p_0\) (both \(l_0\) and \(l_d(z_0, z_1)\) use \(p_0\)). The middle and terminal blocks need \(p^- = p_{i-1}\) for the previous interstage term and \(p = p_i\) for the current stage/interstage terms.
The global gradient stacks the per-block functions:

\[
\nabla L(Z) = \begin{bmatrix}
\bar{g}_0(z_0, z_1, p_0) \\
\bar{g}(z_0, z_1, z_2, p_0, p_1) \\
\vdots \\
\bar{g}(z_{N-2}, z_{N-1}, z_N, p_{N-2}, p_{N-1}) \\
\bar{g}_N(z_{N-1}, z_N, p_{N-1}, p_N)
\end{bmatrix}
\]
Cost Hessian¶
The Hessian \(\nabla^2 L(Z)\) is block-tridiagonal. Reading off the blocks from the Jacobian of the gradient:
Diagonal blocks \(P_i = D_z \bar{g}_i\) are symmetric (Hessians of scalar functions). Upper off-diagonal blocks \(S_i = D_{z^+} \bar{g}_i\) are plain Jacobians. The lower off-diagonal blocks are not computed — by symmetry they equal \(S_i^\top\).
Expanding the off-diagonal blocks:

\[
S_i = D_{z^+}\,\bar{g}_i = \nabla^2_{z\,z^+} l_d(z_i, z_{i+1}, p_i), \qquad i = 0, \ldots, N-1,
\]
so all off-diagonal blocks share the same functional form \(\nabla^2_{z\,z^+} l_d\), computed via a single spjacobian of \(\nabla_z l_d\) with respect to its second argument. Each off-diagonal block receives the corresponding per-stage parameter \(p_i\).
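A finite-difference sketch of this shared functional form, for a hypothetical interstage cost \(l_d(z, z^+, p) = \tfrac{1}{2} p \|z^+ - z\|^2\):

```python
import numpy as np

def grad_ld_z(z, zn, p):
    # analytic ∇_z l_d for l_d = ½ p ‖z⁺ − z‖² (hypothetical example)
    return -p * (zn - z)

def cross_hessian(z, zn, p, eps=1e-6):
    # S = d/dz⁺ (∇_z l_d): central finite differences column by column
    nz = z.size
    S = np.zeros((nz, nz))
    for j in range(nz):
        e = np.zeros(nz)
        e[j] = eps
        S[:, j] = (grad_ld_z(z, zn + e, p) - grad_ld_z(z, zn - e, p)) / (2 * eps)
    return S

z = np.array([1.0, 2.0])
zn = np.array([0.5, -1.0])
S = cross_hessian(z, zn, 2.0)  # for this l_d: S = -p * I
```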
Implementation details¶
Gradient components (cached properties, computed once):
| Property | Function |
|---|---|
| _grad_l0 | gradient(initial_fn, argnum=0) |
| _grad_l | gradient(stage_fn, argnum=0) |
| _grad_lN | gradient(terminal_fn, argnum=0) |
| _grad_ld_z | gradient(interstage_fn, argnum=0) |
| _grad_ld_znext | gradient(interstage_fn, argnum=1) |
Per-block gradient functions (cached NumericalFunction objects wrapping sums of the above):
| Property | Signature | Body |
|---|---|---|
| _g_bar_0 | \((z, z^+, p) \to \mathbb{R}^{n_z}\) | grad_l0(z, p) + grad_ld_z(z, znext, p) |
| _g_bar | \((z^-, z, z^+, p^-, p) \to \mathbb{R}^{n_z}\) | grad_l(z, p) + grad_ld_znext(zprev, z, pprev) + grad_ld_z(z, znext, p) |
| _g_bar_N | \((z^-, z, p^-, p) \to \mathbb{R}^{n_z}\) | grad_lN(z, p) + grad_ld_znext(zprev, z, pprev) |
Each _g_bar_* is a thin @numerical_function wrapper that sums the relevant components; None components are omitted. The p/pprev parameters are threaded through calls but never differentiated. The pprev parameter is needed because interstage terms \(l_d(z_{i-1}, z_i, p_{i-1})\) use the previous stage's parameter.
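A minimal sketch of the middle-block function \(\bar{g}\), with hypothetical quadratic cost gradients (the real components come from gradient(...) over the user's functions); note that p and pprev are only passed through, never differentiated:

```python
import numpy as np

# Hypothetical gradients for l = ½ p ‖z‖² and l_d = ½ p ‖z⁺ − z‖², n_z = 2.
def grad_l(z, p):            return p * z              # ∇_z l(z, p)
def grad_ld_z(z, zn, p):     return -p * (zn - z)      # ∇_z l_d(z, z⁺, p)
def grad_ld_znext(zp, z, p): return p * (z - zp)       # ∇_{z⁺} l_d(z⁻, z, p⁻)

def g_bar(zprev, z, znext, pprev, p):
    # Middle block ḡ: the stage and outgoing interstage terms use the
    # current p; the incoming interstage term uses pprev.
    return (grad_l(z, p)
            + grad_ld_znext(zprev, z, pprev)
            + grad_ld_z(z, znext, p))

g = g_bar(np.zeros(2), np.ones(2), 2 * np.ones(2), 1.0, 2.0)
```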
Hessian block functions (cached SparseNumericalFunction objects):
| Property | Derivation | spjacobian call |
|---|---|---|
| _hess_diag0 | \(P_0 = D_z \bar{g}_0\), symmetric | spjacobian(_g_bar_0, argnum=0, symmetric=True) |
| _hess_diag_inter | \(P_i = D_z \bar{g}\), symmetric, \(i=1\ldots N-1\) | spjacobian(_g_bar, argnum=1, symmetric=True) |
| _hess_diagN | \(P_N = D_z \bar{g}_N\), symmetric | spjacobian(_g_bar_N, argnum=1, symmetric=True) |
| _hess_offdiag | \(S_i = \nabla^2_{z\,z^+} l_d\), shared for all \(i\) | spjacobian(_grad_ld_z, argnum=1) |
symmetric=True triggers star coloring so only the upper-triangular entries are computed and stored.
Build methods and their vectorization:
| Method | Output | Boundary blocks (single call) | Middle blocks (vmapped) |
|---|---|---|---|
| build_cost_fn | Scalar \(L(Z)\) | l0(z_0, p_0), lN(z_N, p_N) | dvmap(ls)(stages[1:N], params[1:N]), dvmap(ld)(stages[:N], stages[1:], params[:N]) |
| build_cost_grad_fn | Dense \(\nabla L(Z)\) | _g_bar_0(z_0, z_1, p_0), _g_bar_N(z_{N-1}, z_N, p_{N-1}, p_N) | svmap(_g_bar, N-1)(stages[:-2], stages[1:-1], stages[2:], params[:-2], params[1:-1]) |
| build_hessian_fn | Sparse upper-tri \(\nabla^2 L\) | _hess_diag0(z_0, z_1, p_0), _hess_diagN(z_{N-1}, z_N, p_{N-1}, p_N) | svmap(_hess_diag_inter, N-1), svmap(_hess_offdiag, N) |
All build methods reshape the flat parameter vector \(p\) to \((N+1, n_{p_s})\) and dispatch per-stage slices. The dvmap/svmap calls use in_axes=(0, ..., 0) to batch over both stage variables and parameters (no None broadcasting).
Assembly of build_hessian_fn follows the same place_blocks pattern as build_gn_hessian_fn: block offsets are precomputed from the global upper-tri CSC sparsity and stamped at runtime via custom_kernel.
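A sketch of the offset-precomputation idea behind this pattern (the semantics of place_blocks are assumed here, and SciPy stands in for the library's sparse types): the position of each block entry in the global CSC data array is computed once, then fresh values are stamped in at runtime.

```python
import numpy as np
from scipy import sparse

# Fixed global sparsity pattern (a dense 2x2 block for illustration).
rows = np.array([0, 1, 0, 1])
cols = np.array([0, 0, 1, 1])
pattern = sparse.csc_matrix((np.ones(4), (rows, cols)), shape=(2, 2))

# Precompute: offset of each (row, col) entry within pattern.data.
offsets = {}
for j in range(pattern.shape[1]):
    for k in range(pattern.indptr[j], pattern.indptr[j + 1]):
        offsets[(pattern.indices[k], j)] = k

def stamp(block):
    # Runtime: write block values into the precomputed data offsets,
    # reusing the fixed indptr/indices arrays.
    data = np.empty_like(pattern.data)
    for (i, j), k in offsets.items():
        data[k] = block[i, j]
    return sparse.csc_matrix((data, pattern.indices, pattern.indptr),
                             shape=pattern.shape)

M = stamp(np.array([[1.0, 2.0], [3.0, 4.0]]))
```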
Implementation structure¶
The key methods on _MultistageFn:
| Method | Output | Strategy |
|---|---|---|
| build_fn | Dense evaluation | Reshape p to \((N+1, n_{p_s})\), dvmap over stage/interstage with per-stage params, interleave results |
| build_jac_fn | Sparse CSC Jacobian | Per-block spjacobian + svmap with per-stage params + place_blocks assembly |
| build_cost_fn | Scalar \(\tfrac{1}{2}\|r\|^2\) | dvmap over blocks with per-stage params, sum of squares |
| build_cost_grad_fn | Dense gradient \(J^\top r\) | Per-block spmv_t with pprev/p dispatch + svmap |
| build_gn_hessian_fn | Sparse upper-tri \(J^\top J\) | Per-block spsyrk/spmm_tn with pprev/p dispatch + svmap + place_blocks |
The key methods on _MultistageCostFn:
| Method | Output | Strategy |
|---|---|---|
| build_cost_fn | Scalar \(L(Z)\) | Reshape p, dvmap over stage/interstage with per-stage params, scalar sum |
| build_cost_grad_fn | Dense gradient \(\nabla L\) | Reshape p, per-block _g_bar_* with pprev/p dispatch + svmap over middle blocks |
| build_hessian_fn | Sparse upper-tri \(\nabla^2 L\) | Reshape p, per-block spjacobian of _g_bar_* with pprev/p + svmap + place_blocks |