Architecture
Design Philosophy
cuda-mppi is built for maximum throughput and minimal latency. To achieve this, it avoids virtual function calls and dynamic memory allocation inside the control loop and relies instead on C++ templates and static polymorphism. The dynamics and cost types are template parameters of the controller, so their calls are resolved at compile time and can be inlined into the kernels.
Core Components
The architecture revolves around three main concepts. The user must define the first two (or generate them via JIT); the third is the solver that uses them:
- Dynamics: Defines how the system state evolves, \(x_{t+1} = f(x_t, u_t)\).
- Cost: Defines the running and terminal cost of a trajectory over a horizon of \(T\) steps, \(J = \sum_{t=0}^{T-1} c(x_t, u_t) + \phi(x_T)\).
- Controller: The MPPI solver that orchestrates sampling and optimization.
Class Diagram
```mermaid
classDiagram
    class MPPIController~Dynamics, Cost~ {
        +compute(state)
        +get_action()
        -rollout_kernel()
        -update_kernel()
    }
    class Dynamics {
        <<Interface>>
        +step(x, u, x_next, dt)*
    }
    class Cost {
        <<Interface>>
        +compute(x, u, t)*
        +terminal_cost(x)*
    }
    MPPIController ..> Dynamics : Uses
    MPPIController ..> Cost : Uses
```
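The following host-side C++ sketch illustrates the template-based approach described under Design Philosophy. DoubleIntegrator, QuadraticCost and trajectory_cost are hypothetical names invented for this example; only the method names step, compute and terminal_cost follow the interfaces in the class diagram, and a single scalar control input is assumed. It is an illustration of the pattern, not cuda-mppi's actual code.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical dynamics: 1-D double integrator, x = [position, velocity], u = acceleration.
struct DoubleIntegrator {
    static constexpr int NX = 2;  // compile-time state dimension
    void step(const float* x, float u, float* x_next, float dt) const {
        x_next[0] = x[0] + x[1] * dt;  // position (explicit Euler)
        x_next[1] = x[1] + u * dt;     // velocity
    }
};

// Hypothetical cost: quadratic running cost on position and control, heavier terminal weight.
struct QuadraticCost {
    float compute(const float* x, float u, int /*t*/) const {
        return x[0] * x[0] + 0.1f * u * u;
    }
    float terminal_cost(const float* x) const {
        return 10.0f * x[0] * x[0];
    }
};

// Dynamics and Cost are template parameters, so dyn.step() and cost.compute() are
// resolved at compile time (no virtual dispatch) and are candidates for inlining.
template <typename Dynamics, typename Cost, std::size_t T>
float trajectory_cost(const Dynamics& dyn, const Cost& cost, const float* x0,
                      const std::array<float, T>& u, float dt) {
    assert(x0 != nullptr);  // caller must supply an initial state
    constexpr int NX = Dynamics::NX;
    float x[NX];
    float x_next[NX];
    for (int i = 0; i < NX; ++i) x[i] = x0[i];

    float J = 0.0f;
    for (std::size_t t = 0; t < T; ++t) {
        J += cost.compute(x, u[t], static_cast<int>(t));  // running cost c(x_t, u_t)
        dyn.step(x, u[t], x_next, dt);                     // x_{t+1} = f(x_t, u_t)
        for (int i = 0; i < NX; ++i) x[i] = x_next[i];
    }
    return J + cost.terminal_cost(x);  // add phi(x_T)
}
```

For example, trajectory_cost(DoubleIntegrator{}, QuadraticCost{}, x0, u, 0.1f) with x0 = {1, 0} and u = {0, 0, 0} returns 13: the position stays at 1, so each of the three steps adds 1 and the terminal cost adds 10. Abstract base classes would also work, but every step() call would then go through a vtable, which is the indirection this design avoids.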
Memory Management
All heavy memory allocations (rollout buffers, noise arrays) are performed once during the controller’s initialization (constructor).
- Device Memory: Allocated via cudaMalloc and held for the lifetime of the controller.
- Host Memory: Minimal host memory is used, primarily for copying the initial state and retrieving the optimal action.
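A minimal host-side sketch of the allocate-once pattern, assuming a row-major [K x T x NU] noise layout. RolloutBuffers and its members are hypothetical; std::vector stands in for cudaMalloc'd device memory so the example runs without a GPU. The point is that storage is sized once in the constructor and every later control cycle reuses the same pointers.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side analogue of the allocate-once pattern. In cuda-mppi the device
// buffers come from cudaMalloc; std::vector stands in here so the sketch runs anywhere.
class RolloutBuffers {
public:
    RolloutBuffers(std::size_t num_samples, std::size_t horizon, std::size_t nu)
        : T_(horizon), NU_(nu),
          noise_(num_samples * horizon * nu),  // one perturbation per (sample, step, control)
          costs_(num_samples) {}               // one accumulated cost per sample

    // Called once per control cycle: clears existing storage in place, never reallocates.
    void reset() { std::fill(costs_.begin(), costs_.end(), 0.0f); }

    // Assumed row-major layout [K x T x NU]: flat index of control j at step t of sample k.
    std::size_t noise_index(std::size_t k, std::size_t t, std::size_t j) const {
        const std::size_t idx = (k * T_ + t) * NU_ + j;
        assert(j < NU_ && idx < noise_.size());
        return idx;
    }

    float* noise() { return noise_.data(); }
    float* costs() { return costs_.data(); }
    std::size_t noise_size() const { return noise_.size(); }

private:
    std::size_t T_;
    std::size_t NU_;
    std::vector<float> noise_;
    std::vector<float> costs_;
};
```

Because reset() only overwrites existing storage, the pointers handed to kernels stay valid for the controller's whole lifetime.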
Rollout Kernel
The heart of the library is the rollout_kernel (in include/mppi/core/kernels.cuh).
- Parallelism: One CUDA thread handles one complete trajectory sample (or a small group of samples).
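- Register Usage: State variables are kept in registers (float x[NX]) to minimize global memory access during the integration loop. This relies on NX being a compile-time constant so the loops over the array can be unrolled; arrays indexed with values unknown at compile time are placed in local memory instead.
- Shared Memory: Used sparingly in standard rollouts to maximize occupancy, but utilized in KMPPI for kernel matrix operations.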
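A one-thread-per-sample mapping typically uses the standard CUDA indexing idiom k = blockIdx.x * blockDim.x + threadIdx.x. The sketch below emulates such a 1-D launch on the host to check that every sample index gets exactly one thread. grid_size and emulate_launch are hypothetical helpers, and the block size is the caller's choice, not cuda-mppi's launch configuration.

```cpp
#include <cassert>
#include <vector>

// Number of blocks needed so that num_samples threads are available (ceiling division).
inline int grid_size(int num_samples, int block_dim) {
    assert(block_dim > 0);
    return (num_samples + block_dim - 1) / block_dim;
}

// Host emulation of a 1-D launch: returns how often each sample index is visited.
// A correct mapping visits every sample exactly once.
std::vector<int> emulate_launch(int num_samples, int block_dim) {
    std::vector<int> visits(num_samples, 0);
    const int grid = grid_size(num_samples, block_dim);
    for (int block = 0; block < grid; ++block) {          // blockIdx.x
        for (int thread = 0; thread < block_dim; ++thread) {  // threadIdx.x
            const int k = block * block_dim + thread;
            if (k >= num_samples) continue;  // in a kernel: `if (k >= K) return;` guards the partial last block
            ++visits[k];                     // here thread k would roll out sample k completely
        }
    }
    return visits;
}
```

With 1000 samples and 256 threads per block, grid_size returns 4; the last 24 threads of the final block fail the guard and do no work.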
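```cpp
// Simplified kernel logic for sample k (one thread), single control input
// x_curr[NX] and x_next[NX] are register-resident; x_curr starts at the current state
float total_cost = 0.0f;
for (int t = 0; t < Horizon; ++t) {
    // 1. Compute control: nominal plus this sample's noise (row-major [K x Horizon])
    float u_val = u_nom[t] + noise[k * Horizon + t];
    // 2. Step dynamics (inlined)
    dynamics.step(x_curr, u_val, x_next, dt);
    // 3. Accumulate running cost
    total_cost += cost.compute(x_curr, u_val, t);
    // 4. Update register state (C arrays cannot be assigned, so copy element-wise)
    for (int i = 0; i < NX; ++i) x_curr[i] = x_next[i];
}
// 5. Add terminal cost phi(x_T)
total_cost += cost.terminal_cost(x_curr);
```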
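Once every sample has a total cost \(S_k\), the controller turns the costs into a new nominal control sequence. The sketch below shows the standard MPPI weighted update from the literature, not necessarily what cuda-mppi's update_kernel does: \(\eta = \sum_k \exp(-(S_k - \min_j S_j)/\lambda)\), \(w_k = \exp(-(S_k - \min_j S_j)/\lambda)/\eta\) and \(u_t \leftarrow u_t + \sum_k w_k \epsilon_{k,t}\). The function name mppi_update, the temperature lambda and the row-major [K x T] noise layout with a single control input are assumptions for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Standard MPPI update (hypothetical host version). Subtracting the minimum cost keeps
// the largest weight at exp(0) = 1, so the normalizer never underflows to zero.
// The exponentials are computed twice instead of stored, so nothing is allocated.
void mppi_update(std::vector<float>& u_nom, const std::vector<float>& costs,
                 const std::vector<float>& noise,  // row-major [K x T]
                 float lambda) {
    const std::size_t K = costs.size();
    const std::size_t T = u_nom.size();
    assert(K > 0 && noise.size() == K * T && lambda > 0.0f);
    const float min_cost = *std::min_element(costs.begin(), costs.end());

    // Pass 1: normalizer eta = sum_k exp(-(S_k - min S) / lambda).
    float eta = 0.0f;
    for (std::size_t k = 0; k < K; ++k) eta += std::exp(-(costs[k] - min_cost) / lambda);

    // Pass 2: add each sample's noise, scaled by its normalized weight.
    for (std::size_t k = 0; k < K; ++k) {
        const float w = std::exp(-(costs[k] - min_cost) / lambda) / eta;
        for (std::size_t t = 0; t < T; ++t) u_nom[t] += w * noise[k * T + t];
    }
}
```

With two samples of equal cost, both weights are 0.5 and the update moves the nominal sequence to the average of the two noise sequences; a sample whose cost exceeds the minimum by many multiples of lambda contributes essentially nothing.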