Multistage formulation

Nonlinear program

The MultistageFn / MultistageProblem classes implement a uniform-stage-dimension multistage NLP:

\[ \begin{align*} \min_{z_i} & \quad \tfrac{1}{2}\|r_0(z_0)\|^2 + \sum_{i=0}^{N-1} \tfrac{1}{2}\|r_d(z_i,z_{i+1})\|^2 + \sum_{i=1}^{N-1}\tfrac{1}{2}\|r(z_i)\|^2 + \tfrac{1}{2}\|r_N(z_N)\|^2 \\ \text{s.t.} & \quad a_0(z_0) = 0, \\ & \quad a_d(z_i, z_{i+1}) = 0, \quad i=0,\ldots,N-1, \\ & \quad a_N(z_N) = 0, \\ & \quad g_{\text{lb}} \leq g_0(z_0) \leq g_{\text{ub}}, \\ & \quad g_{\text{lb}} \leq g(z_i) \leq g_{\text{ub}}, \quad i=1,\ldots,N-1, \\ & \quad g_{\text{lb}} \leq g_N(z_N) \leq g_{\text{ub}}, \\ & \quad g_{d,\text{lb}} \leq g_d(z_i, z_{i+1}) \leq g_{d,\text{ub}}, \quad i=0,\ldots,N-1 \end{align*} \]

where all stage variables have the same dimension \(z_i \in \mathbb{R}^{n_z}\) and all functions take a per-stage runtime parameter \(p_i \in \mathbb{R}^{n_{p_s}}\) as trailing input (omitted above for clarity). All derivatives are taken with respect to stage variables only; \(p_i\) is fixed data.

Per-stage parameter dispatch

The solver sees a flat parameter vector \(p \in \mathbb{R}^{n_{p_s} \cdot (N+1)}\). Internally, it is reshaped to \((N+1, n_{p_s})\) and \(p_i\) is dispatched to stage \(i\)'s functions:

  • Stage functions \(f(z_i, p_i)\) receive \(p_i\).
  • Interstage functions \(f(z_i, z_{i+1}, p_i)\) receive \(p_i\) (source stage parameter).
  • The per-stage parameter dimension \(n_{p_s}\) is inferred from the component functions' last input shape (stage_param_dim property).
  • The total parameter dimension is \(n_p = n_{p_s} \cdot (N+1)\) (nparam property).
  • Users with shared (global) parameters should tile them \(N+1\) times.
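The reshape-and-dispatch convention can be made concrete with a small NumPy sketch (the code below mirrors the description; it is not the library's internal implementation):

```python
import numpy as np

N = 4        # horizon: stages 0..N
n_ps = 3     # per-stage parameter dimension n_{p_s}

# Flat parameter vector as the solver sees it: length n_ps * (N + 1).
p_flat = np.arange(n_ps * (N + 1), dtype=float)

# Reshape to (N + 1, n_ps); row i is p_i, dispatched to stage i's functions.
p = p_flat.reshape(N + 1, n_ps)
p_2 = p[2]   # received by stage f(z_2, p_2) and interstage f(z_2, z_3, p_2)

# A shared (global) parameter theta must be tiled N + 1 times before flattening.
theta = np.array([10.0, 20.0, 30.0])
p_shared = np.tile(theta, N + 1)
```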

Assumptions:

  • uniform stage dimension across the whole horizon
  • uniform per-stage parameter dimension across all component functions
  • stage/interstage functions are the same at all applicable stages (initial and terminal stages may use different functions)
  • all callbacks accept trailing per-stage p input: stage functions take (z, p), interstage functions take (z, znext, p)
  • costs have the nonlinear least-squares form \(l(z, p) = \tfrac{1}{2}\|r(z, p)\|^2\)
  • equality constraint RHS is always 0

Mapping to MultistageFn

Each of the three NLP components (equality constraints, inequality constraints, residuals) is represented by a MultistageFn with four optional sub-functions:

| MultistageFn field | Equality | Inequality | Residual |
| --- | --- | --- | --- |
| initial_fn | \(a_0(z_0, p_0)\) | \(g_0(z_0, p_0)\) | \(r_0(z_0, p_0)\) |
| stage_fn | | \(g(z_i, p_i)\) | \(r(z_i, p_i)\) |
| terminal_fn | \(a_N(z_N, p_N)\) | \(g_N(z_N, p_N)\) | \(r_N(z_N, p_N)\) |
| interstage_fn | \(a_d(z_i, z_{i+1}, p_i)\) | \(g_d(z_i, z_{i+1}, p_i)\) | \(r_d(z_i, z_{i+1}, p_i)\) |

Component functions define their signature with the per-stage parameter dimension \(n_{p_s}\). The param_dim field on _MultistageFn stores \(n_{p_s}\); the total \(n_p = n_{p_s} \cdot (N+1)\) is available via nparam.

MultistageProblem composes three MultistageFn instances and exposes the SQP interface.

Global row ordering

Each MultistageFn evaluates to a vector with row ordering:

initial(z_0)
interstage(z_0,z_1), stage(z_1)
interstage(z_1,z_2), stage(z_2)
...
interstage(z_{N-2},z_{N-1}), stage(z_{N-1})
interstage(z_{N-1},z_N)
terminal(z_N)
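A throwaway helper (illustrative only, not library code) reproduces this ordering for any N; note that the final interstage block has no stage block after it:

```python
def multistage_row_order(N):
    """Block labels in the global row ordering of a MultistageFn output."""
    blocks = ["initial(z_0)"]
    for i in range(N):
        blocks.append(f"interstage(z_{i},z_{i+1})")
        if i + 1 < N:                 # no stage block at the terminal stage
            blocks.append(f"stage(z_{i+1})")
    blocks.append(f"terminal(z_{N})")
    return blocks

# For N = 3 this reproduces the listing above.
order = multistage_row_order(3)
```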

Multistage QP subproblem (PIQP interface)

At each SQP iteration, the NLP is linearized into a block-tridiagonal QP. PIQP's multistage backend expects the form:

\[ \begin{align*} \min_{Z} & \quad \tfrac{1}{2} Z^\top P Z + c^\top Z \\ \text{s.t.} & \quad A_i z_i + B_i z_{i+1} = b_i, \quad i=0,\ldots,N-1, \\ & \quad A_N z_N = b_N, \\ & \quad d_{\text{lb},i} \leq C_i z_i + D_i z_{i+1} \leq d_{\text{ub},i}, \quad i=0,\ldots,N-1, \\ & \quad d_{\text{lb},N} \leq C_N z_N \leq d_{\text{ub},N}, \\ & \quad z_{\text{lb},i} \leq z_i \leq z_{\text{ub},i}, \quad i=0,\ldots,N \end{align*} \]

where \(Z = \begin{bmatrix} z_0^\top & z_1^\top & \cdots & z_N^\top \end{bmatrix}^\top\) and the global matrices are:

\[ P = \begin{bmatrix} P_0 & S_0 & 0 & \cdots & 0 \\ \star & P_1 & S_1 & \ddots & \vdots \\ 0 & \star & \ddots & \ddots & 0 \\ \vdots & \ddots & \ddots & P_{N-1} & S_{N-1} \\ 0 & \cdots & \cdots & \star & P_N \end{bmatrix}, \quad A = \begin{bmatrix} A_0 & B_0 & 0 & \cdots & 0 \\ 0 & A_1 & B_1 & \ddots & \vdots \\ \vdots & \ddots & \ddots & \ddots & 0 \\ \vdots & & \ddots & A_{N-1} & B_{N-1} \\ 0 & \cdots & \cdots & 0 & A_N \end{bmatrix}, \quad C = \begin{bmatrix} C_0 & D_0 & 0 & \cdots & 0 \\ 0 & C_1 & D_1 & \ddots & \vdots \\ \vdots & \ddots & \ddots & \ddots & 0 \\ \vdots & & \ddots & C_{N-1} & D_{N-1} \\ 0 & \cdots & \cdots & 0 & C_N \end{bmatrix} \]
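A dense NumPy sketch fixes the layout of the global \(P\) from the per-stage blocks \(P_i\) and \(S_i\) (the implementation keeps everything sparse; this is only to pin down indices):

```python
import numpy as np

def assemble_tridiagonal(P_blocks, S_blocks):
    """Dense block-tridiagonal matrix from diagonal blocks P_i and upper
    off-diagonal blocks S_i; lower blocks (the starred entries) are S_i^T."""
    n_z = P_blocks[0].shape[0]
    n = len(P_blocks) * n_z
    P = np.zeros((n, n))
    for i, Pi in enumerate(P_blocks):
        P[i*n_z:(i+1)*n_z, i*n_z:(i+1)*n_z] = Pi
    for i, Si in enumerate(S_blocks):
        P[i*n_z:(i+1)*n_z, (i+1)*n_z:(i+2)*n_z] = Si
        P[(i+1)*n_z:(i+2)*n_z, i*n_z:(i+1)*n_z] = Si.T
    return P

n_z, N = 2, 3
rng = np.random.default_rng(0)
P_blocks = [float(i + 1) * np.eye(n_z) for i in range(N + 1)]
S_blocks = [rng.standard_normal((n_z, n_z)) for _ in range(N)]
P = assemble_tridiagonal(P_blocks, S_blocks)
```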

Equality constraint blocks

From the linearized equality MultistageFn (shown for \(N=3\)):

\[ A = \left[ \begin{array}{c|c|c|c} \overset{z_0}{D a_0} & \overset{z_1}{0} & \overset{z_2}{0} & \overset{z_3}{0} \\ \hline D_z a_d & D_{z^+} a_d & 0 & 0 \\ \hline 0 & D_z a_d & D_{z^+} a_d & 0 \\ \hline 0 & 0 & D_z a_d & D_{z^+} a_d \\ \hline 0 & 0 & 0 & D a_N \end{array} \right] \]

Reading off the per-stage blocks:

\[ \begin{aligned} \left[\begin{array}{c|c} A_0 & B_0 \end{array}\right] &= \left[\begin{array}{c|c} \begin{matrix} D a_0 \\ D_z a_d \end{matrix} & \begin{matrix} 0 \\ D_{z^+} a_d \end{matrix} \end{array}\right], \\ \left[\begin{array}{c|c} A_i & B_i \end{array}\right] &= \left[\begin{array}{c|c} D_z a_d & D_{z^+} a_d \end{array}\right], \quad i=1,\ldots,N-1, \\ A_N &= D a_N \end{aligned} \]

Inequality constraint blocks

From the linearized inequality MultistageFn (shown for \(N=3\)):

\[ C = \left[ \begin{array}{c|c|c|c} \overset{z_0}{D g_0} & \overset{z_1}{0} & \overset{z_2}{0} & \overset{z_3}{0} \\ \hline D_z g_d & D_{z^+} g_d & 0 & 0 \\ \hline 0 & D g & 0 & 0 \\ \hline 0 & D_z g_d & D_{z^+} g_d & 0 \\ \hline 0 & 0 & D g & 0 \\ \hline 0 & 0 & D_z g_d & D_{z^+} g_d \\ \hline 0 & 0 & 0 & D g_N \end{array} \right] \]

Reading off the per-stage blocks:

\[ \begin{aligned} \left[\begin{array}{c|c} C_0 & D_0 \end{array}\right] &= \left[\begin{array}{c|c} \begin{matrix} D g_0 \\ D_z g_d \end{matrix} & \begin{matrix} 0 \\ D_{z^+} g_d \end{matrix} \end{array}\right], \\ \left[\begin{array}{c|c} C_i & D_i \end{array}\right] &= \left[\begin{array}{c|c} \begin{matrix} D g \\ D_z g_d \end{matrix} & \begin{matrix} 0 \\ D_{z^+} g_d \end{matrix} \end{array}\right], \quad i=1,\ldots,N-1, \\ C_N &= D g_N \end{aligned} \]

Residual Jacobian blocks

The residual MultistageFn Jacobian \(R = DR(Z)\) has the same block structure (shown for \(N=3\)):

\[ R = \left[ \begin{array}{c|c|c|c} \overset{z_0}{D r_0} & \overset{z_1}{0} & \overset{z_2}{0} & \overset{z_3}{0} \\ \hline D_z r_d & D_{z^+} r_d & 0 & 0 \\ \hline 0 & D r & 0 & 0 \\ \hline 0 & D_z r_d & D_{z^+} r_d & 0 \\ \hline 0 & 0 & D r & 0 \\ \hline 0 & 0 & D_z r_d & D_{z^+} r_d \\ \hline 0 & 0 & 0 & D r_N \end{array} \right] \]

Cost gradient

The cost gradient \(c = DR(Z)^\top R(Z)\) is computed per-block:

\[ c = \begin{bmatrix} Dr_0(z_0)^\top r_0(z_0) + D_z r_d(z_0,z_1)^\top r_d(z_0,z_1) \\ D r(z_1)^\top r(z_1) + D_{z^+} r_d(z_0, z_1)^\top r_d(z_0, z_1) + D_z r_d(z_1,z_2)^\top r_d(z_1, z_2) \\ \vdots \\ D r(z_{N-1})^\top r(z_{N-1}) + D_{z^+} r_d(z_{N-2}, z_{N-1})^\top r_d(z_{N-2}, z_{N-1}) + D_z r_d(z_{N-1},z_N)^\top r_d(z_{N-1}, z_N) \\ D r_N(z_N)^\top r_N(z_N) + D_{z^+} r_d(z_{N-1},z_N)^\top r_d(z_{N-1},z_N) \end{bmatrix} \]
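On a toy problem with made-up linear residuals, the per-block formula can be checked against the full \(c = DR(Z)^\top R(Z)\) (NumPy sketch, \(N = 2\), parameters omitted):

```python
import numpy as np

rng = np.random.default_rng(1)
n_z, N = 2, 2
A0, As, AN = rng.standard_normal((3, n_z, n_z))   # Jacobians of r_0, r, r_N
Bd, Cd = rng.standard_normal((2, n_z, n_z))       # D_z r_d and D_{z^+} r_d

r0 = lambda z: A0 @ z
rs = lambda z: As @ z
rN = lambda z: AN @ z
rd = lambda z, zn: Bd @ z + Cd @ zn

zs = rng.standard_normal((N + 1, n_z))

# Full residual vector in the global row ordering, and its Jacobian.
R = np.concatenate([r0(zs[0]), rd(zs[0], zs[1]), rs(zs[1]),
                    rd(zs[1], zs[2]), rN(zs[2])])
J = np.zeros((R.size, (N + 1) * n_z))
J[0:2, 0:2] = A0
J[2:4, 0:2], J[2:4, 2:4] = Bd, Cd
J[4:6, 2:4] = As
J[6:8, 2:4], J[6:8, 4:6] = Bd, Cd
J[8:10, 4:6] = AN

# Cost gradient computed block by block, as in the formula above.
c0 = A0.T @ r0(zs[0]) + Bd.T @ rd(zs[0], zs[1])
c1 = As.T @ rs(zs[1]) + Cd.T @ rd(zs[0], zs[1]) + Bd.T @ rd(zs[1], zs[2])
c2 = AN.T @ rN(zs[2]) + Cd.T @ rd(zs[1], zs[2])
c = np.concatenate([c0, c1, c2])
```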

Gauss-Newton Hessian

The Gauss-Newton Hessian \(P = DR(Z)^\top DR(Z)\) is block-tridiagonal. The implementation computes it per-block rather than forming the full Jacobian.

Diagonal blocks:

\[ \begin{aligned} P_0 &= Dr_0^\top Dr_0 + D_z r_d^\top D_z r_d \\ P_i &= D_{z^+} r_d^\top D_{z^+} r_d + Dr^\top Dr + D_z r_d^\top D_z r_d, \quad i=1,\ldots,N-1 \\ P_N &= D_{z^+} r_d^\top D_{z^+} r_d + Dr_N^\top Dr_N \end{aligned} \]

Off-diagonal blocks:

\[ S_i = D_z r_d(z_i, z_{i+1})^\top D_{z^+} r_d(z_i, z_{i+1}), \quad i=0,\ldots,N-1 \]

Each diagonal block is computed with spsyrk (sparse \(J^\top J\)) on the stacked per-argnum Jacobians. Each off-diagonal block is computed with spmm_tn (sparse \(A^\top B\)). Both are batched over the horizon with svmap.
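The per-block Gauss-Newton formulas can likewise be verified against the full \(J^\top J\) on a toy linear-residual setup (NumPy sketch; dense products stand in for spsyrk and spmm_tn):

```python
import numpy as np

rng = np.random.default_rng(2)
n_z, N = 2, 2
A0, As, AN = rng.standard_normal((3, n_z, n_z))   # Dr_0, Dr, Dr_N (constant)
Bd, Cd = rng.standard_normal((2, n_z, n_z))       # D_z r_d, D_{z^+} r_d

# Full residual Jacobian in the global row ordering (N = 2).
J = np.zeros((5 * n_z, (N + 1) * n_z))
J[0:2, 0:2] = A0
J[2:4, 0:2], J[2:4, 2:4] = Bd, Cd
J[4:6, 2:4] = As
J[6:8, 2:4], J[6:8, 4:6] = Bd, Cd
J[8:10, 4:6] = AN
P_full = J.T @ J

# Per-block Gauss-Newton formulas from above.
P0 = A0.T @ A0 + Bd.T @ Bd
P1 = Cd.T @ Cd + As.T @ As + Bd.T @ Bd
P2 = Cd.T @ Cd + AN.T @ AN
S0 = Bd.T @ Cd                # S_i has the same form for every i
S1 = Bd.T @ Cd
```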

General scalar cost (_MultistageCostFn)

For general (non-least-squares) cost functions the NLP objective is:

\[ L(Z; p) = l_0(z_0, p_0) + \sum_{i=0}^{N-1} l_d(z_i,z_{i+1}, p_i) + \sum_{i=1}^{N-1}l(z_i, p_i) + l_N(z_N, p_N) \]

where \(l_0, l, l_N : \mathbb{R}^{n_z} \times \mathbb{R}^{n_{p_s}} \to \mathbb{R}\) and \(l_d : \mathbb{R}^{n_z} \times \mathbb{R}^{n_z} \times \mathbb{R}^{n_{p_s}} \to \mathbb{R}\) are arbitrary smooth scalars (all taking per-stage \(p_i\) as trailing input). This is implemented by _MultistageCostFn.

Cost gradient

The gradient \(\nabla L(Z)\) decomposes block-wise. Since interstage functions use the source stage's parameter (\(l_d(z_i, z_{i+1}, p_i)\) uses \(p_i\)), the per-block gradient functions receive both the current and previous stage parameters where applicable:

\[ \begin{aligned} \bar{g}_0(z, z^+, p) &= \nabla l_0(z, p) + \nabla_z l_d(z, z^+, p) \\ \bar{g}(z^-, z, z^+, p^-, p) &= \nabla l(z, p) + \nabla_{z^+} l_d(z^-, z, p^-) + \nabla_z l_d(z, z^+, p) \\ \bar{g}_N(z^-, z, p^-, p) &= \nabla l_N(z, p) + \nabla_{z^+} l_d(z^-, z, p^-) \end{aligned} \]

Note that \(\bar{g}_0\) only needs \(p_0\) (both \(l_0\) and \(l_d(z_0, z_1)\) use \(p_0\)). The middle and terminal blocks need \(p^- = p_{i-1}\) for the previous interstage term and \(p = p_i\) for the current stage/interstage terms.

The global gradient is:

\[ \nabla L(Z) = \begin{bmatrix} \bar{g}_0(z_0, z_1, p_0) \\ \bar{g}(z_0, z_1, z_2, p_0, p_1) \\ \vdots \\ \bar{g}(z_{N-2}, z_{N-1}, z_N, p_{N-2}, p_{N-1}) \\ \bar{g}_N(z_{N-1}, z_N, p_{N-1}, p_N) \end{bmatrix} \]
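The \(\bar{g}\) block formulas can be checked numerically against finite differences of \(L\), using made-up scalar costs (NumPy sketch, \(N = 2\); the cost functions are illustrative, not from the library):

```python
import numpy as np

rng = np.random.default_rng(3)
n_z = n_ps = 2
N = 2

# Made-up smooth scalar costs, each taking a per-stage parameter last.
l0 = lambda z, p: 0.5 * z @ z + p @ z
ls = lambda z, p: p @ z
lN = lambda z, p: 0.5 * (p @ z) ** 2
ld = lambda z, zn, p: p[0] * (z @ zn)    # interstage term uses source-stage p

def L(Z, params):
    z = Z.reshape(N + 1, n_z)
    return (l0(z[0], params[0])
            + sum(ld(z[i], z[i + 1], params[i]) for i in range(N))
            + sum(ls(z[i], params[i]) for i in range(1, N))
            + lN(z[N], params[N]))

zmat = rng.standard_normal((N + 1, n_z))
params = rng.standard_normal((N + 1, n_ps))

# Per-block gradients following the g_bar definitions above.
g0 = (zmat[0] + params[0]) + params[0][0] * zmat[1]               # g_bar_0(z_0, z_1, p_0)
g1 = params[1] + params[0][0] * zmat[0] + params[1][0] * zmat[2]  # g_bar(z_0, z_1, z_2, p_0, p_1)
g2 = (params[2] @ zmat[2]) * params[2] + params[1][0] * zmat[1]   # g_bar_N(z_1, z_2, p_1, p_2)
grad = np.concatenate([g0, g1, g2])

# Central finite differences of the full cost for comparison.
Z = zmat.ravel()
eps = 1e-6
fd = np.array([(L(Z + eps * e, params) - L(Z - eps * e, params)) / (2 * eps)
               for e in np.eye(Z.size)])
```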

Cost Hessian

The Hessian \(\nabla^2 L(Z)\) is block-tridiagonal. Reading off the blocks from the Jacobian of the gradient:

\[ \nabla^2 L(Z) = \begin{bmatrix} D_z \bar{g}_0 & D_{z^+} \bar{g}_0 & 0 & \cdots & 0 \\ \star & D_z \bar{g} & D_{z^+} \bar{g} & \ddots & \vdots \\ 0 & \star & D_z \bar{g} & \ddots & 0 \\ \vdots & \ddots & \ddots & \ddots & D_{z^+} \bar{g} \\ 0 & \cdots & 0 & \star & D_z \bar{g}_N \end{bmatrix} \]

Diagonal blocks \(P_i = D_z \bar{g}_i\) are symmetric (Hessians of scalar functions). Upper off-diagonal blocks \(S_i = D_{z^+} \bar{g}_i\) are plain Jacobians. The lower off-diagonal blocks are not computed — by symmetry they equal \(S_i^\top\).

Expanding the off-diagonal blocks:

\[ S_0 = D_{z^+} \bar{g}_0(z_0, z_1, p_0) = \nabla^2_{z\,z^+} l_d(z_0, z_1, p_0), \qquad S_i = D_{z^+} \bar{g}(z_{i-1}, z_i, z_{i+1}, p_{i-1}, p_i) = \nabla^2_{z\,z^+} l_d(z_i, z_{i+1}, p_i) \]

so all off-diagonal blocks share the same functional form \(\nabla^2_{z\,z^+} l_d\), computed via a single spjacobian of \(\nabla_z l_d\) with respect to its second argument. Each off-diagonal block receives the corresponding per-stage parameter \(p_i\).
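The shared off-diagonal form can be demonstrated with a finite-difference stand-in for that spjacobian call (the bilinear \(l_d\) below is made up; its \(\nabla^2_{z\,z^+} l_d\) is \(p_0 I\)):

```python
import numpy as np

n_z = 2
# Made-up interstage cost l_d(z, z+, p) = p[0] * (z @ z+); its z-gradient:
grad_ld_z = lambda z, zn, p: p[0] * zn

# Jacobian of grad_ld_z with respect to the second argument z+ (a
# finite-difference stand-in for spjacobian(_grad_ld_z, argnum=1)).
z = np.ones(n_z)
zn = np.array([2.0, -1.0])
p = np.array([3.0, 0.0])
eps = 1e-6
cols = [(grad_ld_z(z, zn + eps * e, p) - grad_ld_z(z, zn - eps * e, p)) / (2 * eps)
        for e in np.eye(n_z)]
S = np.column_stack(cols)    # expect p[0] * I = 3 * I
```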

Implementation details

Gradient components (cached properties, computed once):

| Property | Function |
| --- | --- |
| _grad_l0 | gradient(initial_fn, argnum=0) |
| _grad_l | gradient(stage_fn, argnum=0) |
| _grad_lN | gradient(terminal_fn, argnum=0) |
| _grad_ld_z | gradient(interstage_fn, argnum=0) |
| _grad_ld_znext | gradient(interstage_fn, argnum=1) |

Per-block gradient functions (cached NumericalFunction objects wrapping sums of the above):

| Property | Signature | Body |
| --- | --- | --- |
| _g_bar_0 | \((z, z^+, p) \to \mathbb{R}^{n_z}\) | grad_l0(z, p) + grad_ld_z(z, znext, p) |
| _g_bar | \((z^-, z, z^+, p^-, p) \to \mathbb{R}^{n_z}\) | grad_l(z, p) + grad_ld_znext(zprev, z, pprev) + grad_ld_z(z, znext, p) |
| _g_bar_N | \((z^-, z, p^-, p) \to \mathbb{R}^{n_z}\) | grad_lN(z, p) + grad_ld_znext(zprev, z, pprev) |

Each _g_bar_* is a thin @numerical_function wrapper that sums the relevant components; None components are omitted. The p/pprev parameters are threaded through calls but never differentiated. The pprev parameter is needed because interstage terms \(l_d(z_{i-1}, z_i, p_{i-1})\) use the previous stage's parameter.

Hessian block functions (cached SparseNumericalFunction objects):

| Property | Derivation | spjacobian call |
| --- | --- | --- |
| _hess_diag0 | \(P_0 = D_z \bar{g}_0\), symmetric | spjacobian(_g_bar_0, argnum=0, symmetric=True) |
| _hess_diag_inter | \(P_i = D_z \bar{g}\), symmetric, \(i=1,\ldots,N-1\) | spjacobian(_g_bar, argnum=1, symmetric=True) |
| _hess_diagN | \(P_N = D_z \bar{g}_N\), symmetric | spjacobian(_g_bar_N, argnum=1, symmetric=True) |
| _hess_offdiag | \(S_i = \nabla^2_{z\,z^+} l_d\), shared for all \(i\) | spjacobian(_grad_ld_z, argnum=1) |

symmetric=True triggers star coloring so only the upper-triangular entries are computed and stored.

Build methods and their vectorization:

| Method | Output | Boundary blocks (single call) | Middle blocks (vmapped) |
| --- | --- | --- | --- |
| build_cost_fn | Scalar \(L(Z)\) | l0(z_0, p_0), lN(z_N, p_N) | dvmap(ls)(stages[1:N], params[1:N]), dvmap(ld)(stages[:N], stages[1:], params[:N]) |
| build_cost_grad_fn | Dense \(\nabla L(Z)\) | _g_bar_0(z_0, z_1, p_0), _g_bar_N(z_{N-1}, z_N, p_{N-1}, p_N) | svmap(_g_bar, N-1)(stages[:-2], stages[1:-1], stages[2:], params[:-2], params[1:-1]) |
| build_hessian_fn | Sparse upper-tri \(\nabla^2 L\) | _hess_diag0(z_0, z_1, p_0), _hess_diagN(z_{N-1}, z_N, p_{N-1}, p_N) | svmap(_hess_diag_inter, N-1), svmap(_hess_offdiag, N) |

All build methods reshape the flat parameter vector \(p\) to \((N+1, n_{p_s})\) and dispatch per-stage slices. The dvmap/svmap calls use in_axes=(0, ..., 0) to batch over both stage variables and parameters (no None broadcasting).

Assembly of build_hessian_fn follows the same place_blocks pattern as build_gn_hessian_fn: block offsets are precomputed from the global upper-tri CSC sparsity and stamped at runtime via custom_kernel.
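A pure-Python sketch of the place_blocks idea using SciPy (the names stamp and index are illustrative; the actual implementation stamps precomputed offsets via custom_kernel):

```python
import numpy as np
from scipy import sparse

# Build the global upper-tri block-tridiagonal sparsity once (N = 2, n_z = 2).
n_z, N = 2, 2
pattern = sparse.lil_matrix(((N + 1) * n_z, (N + 1) * n_z))
for i in range(N + 1):                                   # diagonal blocks
    pattern[i*n_z:(i+1)*n_z, i*n_z:(i+1)*n_z] = 1.0
for i in range(N):                                       # upper off-diagonal blocks
    pattern[i*n_z:(i+1)*n_z, (i+1)*n_z:(i+2)*n_z] = 1.0
csc = sparse.triu(pattern).tocsc()

# Precompute (row, col) -> position in csc.data; done once, reused every stamp.
index = {}
for col in range(csc.shape[1]):
    for k in range(csc.indptr[col], csc.indptr[col + 1]):
        index[(csc.indices[k], col)] = k

def stamp(block, r0, c0):
    """Write a dense block's entries into csc.data at the precomputed offsets;
    entries outside the stored (upper-tri) sparsity are skipped."""
    for (r, c), v in np.ndenumerate(block):
        k = index.get((r0 + r, c0 + c))
        if k is not None:
            csc.data[k] = v

stamp(np.full((n_z, n_z), 7.0), 0, 0)   # stamp values into the P_0 block
```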

Implementation structure

The key methods on _MultistageFn:

| Method | Output | Strategy |
| --- | --- | --- |
| build_fn | Dense evaluation | Reshape p to \((N+1, n_{p_s})\), dvmap over stage/interstage with per-stage params, interleave results |
| build_jac_fn | Sparse CSC Jacobian | Per-block spjacobian + svmap with per-stage params + place_blocks assembly |
| build_cost_fn | Scalar \(\tfrac{1}{2}\|r\|^2\) | dvmap over blocks with per-stage params, sum of squares |
| build_cost_grad_fn | Dense gradient \(J^\top r\) | Per-block spmv_t with pprev/p dispatch + svmap |
| build_gn_hessian_fn | Sparse upper-tri \(J^\top J\) | Per-block spsyrk/spmm_tn with pprev/p dispatch + svmap + place_blocks |

The key methods on _MultistageCostFn:

| Method | Output | Strategy |
| --- | --- | --- |
| build_cost_fn | Scalar \(L(Z)\) | Reshape p, dvmap over stage/interstage with per-stage params, scalar sum |
| build_cost_grad_fn | Dense gradient \(\nabla L\) | Reshape p, per-block _g_bar_* with pprev/p dispatch + svmap over middle blocks |
| build_hessian_fn | Sparse upper-tri \(\nabla^2 L\) | Reshape p, per-block spjacobian of _g_bar_* with pprev/p + svmap + place_blocks |