# jax_mppi

jax_mppi is a functional, JIT-compilable port of the `pytorch_mppi` library to JAX. It implements Model Predictive Path Integral (MPPI) control with a focus on performance and composability.
## Design Philosophy
This library embraces JAX’s functional paradigm:
- Pure Functions: Core logic is implemented as pure functions: `command(state, mppi_state) -> (action, mppi_state)`.
- Dataclass State: State is held in `jax.tree_util.register_dataclass` containers, allowing easy integration with `jit`, `vmap`, and `grad` (see the sketch below).
- No Side Effects: Unlike the PyTorch version, there is no mutable `self`; state transitions are explicit.
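As a rough sketch of what such a registered-dataclass state container can look like (the field names here are illustrative assumptions, not the library's actual definitions):

```python
import dataclasses

import jax
import jax.numpy as jnp

# Hypothetical state container; fields are assumptions for illustration.
@dataclasses.dataclass
class MPPIState:
    U: jax.Array        # nominal action sequence, shape (horizon, nu)
    rng_key: jax.Array  # PRNG state carried explicitly between calls

jax.tree_util.register_dataclass(
    MPPIState,
    data_fields=["U", "rng_key"],
    meta_fields=[],
)

# Because MPPIState is now a pytree, it flows through jit/vmap/grad:
@jax.jit
def shift_nominal(state: MPPIState) -> MPPIState:
    # Roll the nominal sequence forward one step, as MPPI typically does.
    return dataclasses.replace(state, U=jnp.roll(state.U, -1, axis=0))
```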
## Key Features
- Core MPPI: Robust implementation of the standard MPPI algorithm.
- Smooth MPPI (SMPPI): Maintains action sequences and smoothness costs for better trajectory generation.
- Kernel MPPI (KMPPI): Uses kernel interpolation for control points, reducing the parameter space.
- Autotuning: Built-in hyperparameter optimization using CMA-ES, Ray Tune, and Quality Diversity.
- CUDA/C++ Backend: High-performance implementations of all controllers in CUDA/C++17, exposed to Python via `nanobind`.
- JAX Integration (see the sketch after this list):
  - `jax.vmap` for efficient batch processing.
  - `jax.lax.scan` for fast horizon loops.
  - Fully compatible with JIT compilation for high-performance control loops.
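The sketch below illustrates the `lax.scan`-over-horizon, `vmap`-over-samples pattern these features refer to; `make_batch_cost` and all shapes are illustrative assumptions, not the library's internals:

```python
import jax
import jax.numpy as jnp

def make_batch_cost(dynamics, running_cost):
    """Build a function mapping (x0, action_batch) -> per-sample costs."""
    def rollout_cost(x0, actions):
        # lax.scan walks the horizon once, threading the state through.
        def step(state, action):
            next_state = dynamics(state, action)
            return next_state, running_cost(next_state, action)
        _, costs = jax.lax.scan(step, x0, actions)  # actions: (horizon, nu)
        return jnp.sum(costs)
    # vmap evaluates all K sampled sequences in parallel.
    return jax.vmap(rollout_cost, in_axes=(None, 0))

# Example: 1024 random action sequences for a trivial additive system.
key = jax.random.PRNGKey(0)
batch_cost = jax.jit(make_batch_cost(
    dynamics=lambda x, u: x + u,
    running_cost=lambda x, u: jnp.sum(x**2) + jnp.sum(u**2),
))
samples = jax.random.normal(key, (1024, 20, 2))  # (K, horizon, nu)
costs = batch_cost(jnp.zeros(2), samples)        # shape (1024,)
```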
## Installation

```bash
# Clone the repository
git clone https://github.com/yourusername/jax_mppi.git
cd jax_mppi

# Install dependencies
pip install -e .
```

## Usage
```python
import jax
import jax.numpy as jnp

from jax_mppi import mppi

# Define dynamics and cost functions
def dynamics(state, action):
    # Your dynamics model here
    return state + action

def running_cost(state, action):
    # Your cost function here
    return jnp.sum(state**2) + jnp.sum(action**2)

# Create configuration and initial state
config, mppi_state = mppi.create(
    nx=4, nu=2,
    noise_sigma=jnp.eye(2) * 0.1,
    horizon=20,
    lambda_=1.0,
)

# Control loop
key = jax.random.PRNGKey(0)
current_obs = jnp.zeros(4)

# JIT compile the command function for performance
jitted_command = jax.jit(mppi.command, static_argnames=['dynamics', 'running_cost'])

for _ in range(100):
    key, subkey = jax.random.split(key)
    action, mppi_state = jitted_command(
        config,
        mppi_state,
        current_obs,
        dynamics=dynamics,
        running_cost=running_cost,
    )
    # Apply action to environment...
```
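To give a sense of the control-point idea behind the KMPPI variant listed above, here is a minimal kernel-interpolation sketch: noise is sampled on a few control points and expanded to the full horizon with an RBF kernel. Function names and the kernel choice are assumptions for illustration, not `kmppi.py`'s actual internals.

```python
import jax.numpy as jnp

def rbf_kernel(t1, t2, scale=0.1):
    # Pairwise squared-exponential kernel between two sets of times.
    d = t1[:, None] - t2[None, :]
    return jnp.exp(-0.5 * (d / scale) ** 2)

def interpolate_controls(control_points, horizon, scale=0.1):
    # control_points: (num_cp, nu) -> dense actions: (horizon, nu)
    num_cp = control_points.shape[0]
    t_cp = jnp.linspace(0.0, 1.0, num_cp)
    t_dense = jnp.linspace(0.0, 1.0, horizon)
    K = rbf_kernel(t_dense, t_cp, scale)                  # (horizon, num_cp)
    K_cp = rbf_kernel(t_cp, t_cp, scale) + 1e-6 * jnp.eye(num_cp)
    weights = jnp.linalg.solve(K_cp, control_points)      # (num_cp, nu)
    return K @ weights                                    # (horizon, nu)

# Ten control points expanded to a 50-step, 2-input action sequence:
cps = jnp.zeros((10, 2)).at[3].set(jnp.array([1.0, -1.0]))
dense = interpolate_controls(cps, horizon=50)             # (50, 2)
```

Perturbing a handful of control points instead of every timestep is what shrinks the sampled parameter space while keeping trajectories smooth.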
## Project Structure

```text
jax_mppi/
├── src/jax_mppi/
│   ├── mppi.py              # Core MPPI implementation
│   ├── smppi.py             # Smooth MPPI variant
│   ├── kmppi.py             # Kernel MPPI variant
│   ├── types.py             # Type definitions
│   ├── autotune.py          # Autotuning core & CMA-ES
│   ├── autotune_global.py   # Ray Tune integration
│   └── autotune_qd.py       # Quality Diversity optimization
├── examples/
│   ├── pendulum.py          # Pendulum environment example
│   ├── autotune_basic.py    # Basic autotuning example
│   ├── autotune_pendulum.py # Autotuning pendulum
│   └── smooth_comparison.py # Comparison of MPPI variants
└── tests/                   # Unit and integration tests
```
## Roadmap
Development is structured in phases:
- Core MPPI: Basic implementation with JAX parity.
- Integration: Pendulum example and verification.
- Smooth MPPI: Implementation of smoothness constraints.
- Kernel MPPI: Kernel-based control parameterization.
- Comparisons: Benchmarking and visual comparisons.
- Autotuning: Parameter optimization using CMA-ES, Ray Tune, and QD.
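For a sense of how the CMA-ES phase can work, here is a minimal ask/tell loop using the standalone `cmaes` package; `evaluate_episode` is a placeholder objective, and none of this reflects the `autotune` module's actual API.

```python
import numpy as np
from cmaes import CMA  # standalone CMA-ES package, used here for illustration

def evaluate_episode(lambda_, sigma_scale):
    # Placeholder objective: in practice, run the controller for one episode
    # (e.g., the Usage loop above) and return its total cost.
    return np.log(lambda_) ** 2 + np.log(sigma_scale / 0.1) ** 2

def objective(params):
    # Optimize in log-space so both hyperparameters stay positive.
    lambda_, sigma_scale = np.exp(params)
    return evaluate_episode(lambda_, sigma_scale)

optimizer = CMA(mean=np.zeros(2), sigma=0.5)
for _ in range(30):
    solutions = []
    for _ in range(optimizer.population_size):
        x = optimizer.ask()
        solutions.append((x, objective(x)))
    optimizer.tell(solutions)
```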
## Credits
This project is a direct port of `pytorch_mppi`. We aim to maintain parity with the original implementation while leveraging JAX's unique features for performance and flexibility.