Autotuning Guide
JAX-MPPI includes an autotuning framework for optimizing MPPI hyperparameters such as the temperature \(\lambda\), the noise covariance \(\Sigma\), and the planning horizon. The framework supports several optimization strategies: CMA-ES, Ray Tune, and Quality Diversity (QD) methods.
Overview
The autotuning process involves three main components:
- Tunable Parameters: the parameters you want to optimize (e.g., LambdaParameter, NoiseSigmaParameter).
- Evaluation Function: a function that runs MPPI with a specific configuration and returns a cost (and optionally other metrics).
- Optimizer: the algorithm used to search for the best parameters (e.g., CMAESOpt).
Basic Usage (CMA-ES)
The autotune module provides a simple interface for CMA-ES optimization.
```python
import jax.numpy as jnp
from jax_mppi import mppi, autotune

# 1. Setup MPPI
config, state = mppi.create(...)
holder = autotune.ConfigStateHolder(config, state)

# 2. Define evaluation
def evaluate():
    # Run simulation with holder.config and holder.state
    # Calculate performance cost
    return autotune.EvaluationResult(mean_cost=cost, ...)

# 3. Create Tuner
tuner = autotune.Autotune(
    params_to_tune=[
        autotune.LambdaParameter(holder, min_value=0.1),
        autotune.NoiseSigmaParameter(holder, min_value=0.1),
    ],
    evaluate_fn=evaluate,
    optimizer=autotune.CMAESOpt(population=10),
)

# 4. Optimize
best_result = tuner.optimize_all(iterations=30)
print(f"Best parameters: {best_result.params}")
```

See examples/autotuning/basic.py and examples/autotuning/pendulum.py for complete running examples.
Advanced Usage
Global Optimization with Ray Tune
For more complex search spaces or when you want to use advanced schedulers and search algorithms (like HyperOpt or Bayesian Optimization), use autotune_global.
Note: Requires ray[tune], hyperopt, and bayesian-optimization.
```python
from ray import tune
from jax_mppi import autotune_global as autog

# holder and evaluate are defined as in the basic CMA-ES example above.
# Define search space using Ray Tune's API
params = [
    autog.GlobalLambdaParameter(holder, search_space=tune.loguniform(0.1, 10.0)),
    autog.GlobalNoiseSigmaParameter(holder, search_space=tune.uniform(0.1, 2.0)),
]
tuner = autog.AutotuneGlobal(
    params_to_tune=params,
    evaluate_fn=evaluate,
    optimizer=autog.RayOptimizer(),
)
best = tuner.optimize_all(iterations=100)
```

Quality Diversity (QD)
To find a diverse set of high-performing parameters (e.g., parameter sets that work well in different environments or that produce different behaviors, as measured by a behavior descriptor), use autotune_qd.
```python
from jax_mppi import autotune, autotune_qd

tuner = autotune.Autotune(
    params_to_tune=[...],
    evaluate_fn=evaluate,
    optimizer=autotune_qd.CMAMEOpt(population=20, bins=10),
)
```

Tunable Parameters
The framework supports tuning the following parameters out-of-the-box:
- LambdaParameter: MPPI temperature (\(\lambda\)).
- NoiseSigmaParameter: diagonal of the exploration noise covariance.
- MuParameter: exploration noise mean.
- HorizonParameter: planning horizon length (resizes internal buffers automatically).
You can also define custom parameters by subclassing TunableParameter.
Mathematical Formulation
This section details the mathematical foundations of the autotuning algorithms available in jax_mppi.
Hyperparameter Optimization Problem
The goal of autotuning is to find the optimal set of hyperparameters \(\theta\) (e.g., temperature \(\lambda\), noise covariance \(\Sigma\), horizon \(H\)) that minimizes the expected cost of the control task. We formulate this as an optimization problem:
\[ \theta^* = \arg\min_{\theta \in \Theta} \mathcal{J}(\theta) \]
where \(\Theta\) is the admissible hyperparameter space, and the objective function \(\mathcal{J}(\theta)\) is the expected cumulative cost of the closed-loop system under the MPPI controller parameterized by \(\theta\):
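\[ \mathcal{J}(\theta) = \mathbb{E}_{\tau \sim p(\tau \mid \theta)} \left[ \sum_{t} c(\mathbf{x}_t, \mathbf{u}_t) \right] \]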
Here, \(\tau = \{(\mathbf{x}_0, \mathbf{u}_0), \dots \}\) represents a trajectory rollout, and \(c(\mathbf{x}, \mathbf{u})\) is the task cost function. Since \(\mathcal{J}(\theta)\) is typically non-convex and noisy (due to the stochastic nature of MPPI and the environment), we employ derivative-free optimization methods.
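Because \(\mathcal{J}(\theta)\) is an expectation over noisy closed-loop rollouts, an evaluation function typically estimates it by averaging several episodes. The sketch below is a minimal Monte Carlo estimator that also reports the standard error of the mean. The names estimate_objective and toy_rollout_cost are hypothetical: toy_rollout_cost is an illustrative stand-in for an MPPI simulation, not part of jax_mppi.

```python
import math
import random
import statistics


def estimate_objective(rollout_cost, theta, num_rollouts=8, seed=0):
    """Monte Carlo estimate of J(theta) = E[sum_t c(x_t, u_t)].

    rollout_cost(theta, rng) must return the cumulative cost of one
    closed-loop episode. num_rollouts must be at least 2.
    Returns (mean cost, standard error of the mean).
    """
    rng = random.Random(seed)  # fixed seed -> reproducible evaluations
    costs = [rollout_cost(theta, rng) for _ in range(num_rollouts)]
    mean = statistics.fmean(costs)
    stderr = statistics.stdev(costs) / math.sqrt(num_rollouts)
    return mean, stderr


def toy_rollout_cost(theta, rng):
    """Hypothetical episode cost: lowest at lambda = 1, plus Gaussian noise."""
    return math.log10(theta["lambda"]) ** 2 + rng.gauss(0.0, 0.1)


mean, stderr = estimate_objective(toy_rollout_cost, {"lambda": 10.0}, num_rollouts=200)
print(f"J estimate: {mean:.3f} +/- {stderr:.3f}")
```

The standard error makes the noise in \(\mathcal{J}(\theta)\) explicit: more rollouts per candidate give a less noisy ranking at a proportionally higher simulation cost.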
CMA-ES (Covariance Matrix Adaptation Evolution Strategy)
CMA-ES is a state-of-the-art evolutionary algorithm for continuous optimization. It models the population of candidate solutions using a multivariate normal distribution \(\mathcal{N}(\mathbf{m}, \sigma^2 \mathbf{C})\).
The algorithm proceeds in generations \(g\). At each generation:
Sampling: We sample \(\lambda_{pop}\) candidate parameters \(\theta_i\) (offspring):
\[ \theta_i \sim \mathbf{m}^{(g)} + \sigma^{(g)} \mathcal{N}(\mathbf{0}, \mathbf{C}^{(g)}), \quad i = 1, \dots, \lambda_{pop} \]
Evaluation: Each candidate \(\theta_i\) is evaluated by running an MPPI simulation to estimate \(\mathcal{J}(\theta_i)\).
Selection and Recombination: The candidates are sorted by their cost \(\mathcal{J}(\theta_i)\). The top \(\mu\) candidates (parents) are selected to update the mean:
\[ \mathbf{m}^{(g+1)} = \sum_{i=1}^{\mu} w_i \, \theta_{i:\lambda_{pop}} \]
where \(w_i\) are positive weights summing to 1, and \(\theta_{i:\lambda_{pop}}\) denotes the \(i\)-th best candidate.
Covariance Adaptation: The covariance matrix \(\mathbf{C}^{(g)}\) is updated to increase the likelihood of previously successful steps. The update combines two terms:
- Rank-1 Update: Uses the evolution path \(\mathbf{p}_c\) to exploit correlations between consecutive steps.
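- Rank-\(\mu\) Update: Uses the weighted sample covariance of the selected (successful) steps.

Combining both terms gives the covariance update:
\[ \mathbf{C}^{(g+1)} = (1 - c_1 - c_\mu)\, \mathbf{C}^{(g)} + c_1\, \mathbf{p}_c \mathbf{p}_c^T + c_\mu \sum_{i=1}^{\mu} w_i \frac{(\theta_{i:\lambda_{pop}} - \mathbf{m}^{(g)})(\theta_{i:\lambda_{pop}} - \mathbf{m}^{(g)})^T}{\left(\sigma^{(g)}\right)^2} \]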
Step Size Control: The global step size \(\sigma^{(g)}\) is updated using the conjugate evolution path \(\mathbf{p}_\sigma\) to control the overall scale of the distribution.
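To make the sampling, selection, recombination, and rank-\(\mu\) steps concrete, here is a deliberately simplified plain-Python sketch of one generation. It is not jax_mppi's CMAESOpt and not full CMA-ES: it omits the rank-1 update (evolution path \(\mathbf{p}_c\)) and step-size control (\(\mathbf{p}_\sigma\)), so \(\sigma\) stays fixed, and the learning rate c_mu = 0.3 and the log-rank weights are illustrative choices.

```python
import math
import random


def cholesky(C):
    """Return lower-triangular L with L L^T = C (C symmetric positive definite)."""
    n = len(C)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][i] = math.sqrt(C[i][i] - s)
            else:
                L[i][j] = (C[i][j] - s) / L[j][j]
    return L


def cma_generation(objective, m, sigma, C, pop_size, rng, c_mu=0.3):
    """One simplified generation: sample, select, recombine, rank-mu update."""
    n = len(m)
    mu = pop_size // 2
    # Positive, decreasing log-rank weights w_1 >= ... >= w_mu summing to 1.
    raw = [math.log(mu + 0.5) - math.log(i + 1) for i in range(mu)]
    total = sum(raw)
    w = [r / total for r in raw]

    # Sampling: theta_i = m + sigma * L z with z ~ N(0, I), i.e. m + sigma * N(0, C).
    L = cholesky(C)
    samples = []
    for _ in range(pop_size):
        z = [rng.gauss(0.0, 1.0) for _ in range(n)]
        y = [sum(L[r][k] * z[k] for k in range(r + 1)) for r in range(n)]
        samples.append([m[r] + sigma * y[r] for r in range(n)])

    # Evaluation and selection: keep the mu lowest-cost candidates.
    parents = sorted(samples, key=objective)[:mu]

    # Recombination: weighted mean of the parents.
    new_m = [sum(w[i] * parents[i][r] for i in range(mu)) for r in range(n)]

    # Rank-mu update: steps measured from the *old* mean and scaled by 1/sigma.
    steps = [[(p[r] - m[r]) / sigma for r in range(n)] for p in parents]
    new_C = [
        [
            (1 - c_mu) * C[r][s]
            + c_mu * sum(w[i] * steps[i][r] * steps[i][s] for i in range(mu))
            for s in range(n)
        ]
        for r in range(n)
    ]
    return new_m, new_C


# Usage: minimize a 2-D quadratic whose minimum is at (1, -2).
rng = random.Random(0)
cost = lambda th: (th[0] - 1.0) ** 2 + (th[1] + 2.0) ** 2
m, C = [0.0, 0.0], [[1.0, 0.0], [0.0, 1.0]]
for _ in range(60):
    m, C = cma_generation(cost, m, 0.5, C, pop_size=10, rng=rng)
print("mean after 60 generations:", m)
```

Without step-size control the rank-\(\mu\) term has to do all of the scaling: near the optimum the selected steps are short, so \(\mathbf{C}\) shrinks, while far away they are long and \(\mathbf{C}\) grows along the descent direction. Full CMA-ES adapts \(\sigma\) separately, which converges much faster.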
Quality Diversity with CMA-ME
Quality Diversity (QD) algorithms optimize for a set of high-performing solutions that are diverse with respect to a user-defined measure. jax_mppi uses CMA-ME (Covariance Matrix Adaptation MAP-Elites), which combines the search power of CMA-ES with the archive maintenance of MAP-Elites.
Problem Formulation
We seek to find a collection of parameters \(P = \{\theta_1, \dots, \theta_N\}\) that maximize the quality function \(f(\theta) = -\mathcal{J}(\theta)\) while covering the behavior space \(\mathcal{B}\).
Let \(\mathbf{b}(\theta): \Theta \to \mathcal{B}\) be a function mapping parameters to a behavior descriptor (e.g., control smoothness, risk sensitivity).
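In practice \(\mathbf{b}(\theta)\) is computed from the rollouts produced with parameters \(\theta\). As a hypothetical example of the "control smoothness" descriptor mentioned above, the sketch below uses the mean squared change between consecutive control vectors, paired with the peak control magnitude as a second dimension. Both descriptor choices and the function names are illustrative assumptions, not jax_mppi definitions.

```python
def control_smoothness(controls):
    """Mean squared difference between consecutive control vectors (lower = smoother)."""
    total = 0.0
    for u_prev, u_next in zip(controls, controls[1:]):
        total += sum((b - a) ** 2 for a, b in zip(u_prev, u_next))
    return total / (len(controls) - 1)


def behavior_descriptor(controls):
    """Hypothetical 2-D descriptor b = (smoothness, peak |u|) of one rollout."""
    peak = max(abs(x) for u in controls for x in u)
    return (control_smoothness(controls), peak)


# Controls u_0..u_2 from one rollout (1-D control input).
print(behavior_descriptor([[0.0], [1.0], [1.0]]))
```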
MAP-Elites Archive
The behavior space \(\mathcal{B}\) is discretized into a grid of cells (the archive \(\mathcal{A}\)). Each cell \(\mathcal{A}_{\mathbf{z}}\) stores the best solution found so far that maps to that cell index \(\mathbf{z}\):
\[ \mathcal{A}_{\mathbf{z}} = \arg\max_{\theta \,:\, \text{index}(\mathbf{b}(\theta)) = \mathbf{z}} f(\theta) \]
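A grid archive with this "best per cell" rule can be sketched with a dictionary keyed by cell index. This is a minimal illustration, assuming a uniform grid with per-dimension bounds lows/highs and bins cells per dimension (all hypothetical names); descriptors outside the bounds are clipped to the edge cells. The returned \(\Delta\) follows the improvement definition used by CMA-ME, with the raw quality returned when a new cell is filled.

```python
import math


def cell_index(b, lows, highs, bins):
    """Map a behavior descriptor b to a grid cell tuple z (clipped to the grid)."""
    z = []
    for value, lo, hi in zip(b, lows, highs):
        frac = (value - lo) / (hi - lo)
        z.append(min(bins - 1, max(0, math.floor(frac * bins))))
    return tuple(z)


def archive_insert(archive, theta, quality, b, lows, highs, bins):
    """MAP-Elites rule: each cell keeps only its highest-quality solution.

    Returns (status, delta): ("new", quality) if the cell was empty,
    ("improved", quality - old_quality) if theta replaced the occupant,
    or ("rejected", 0.0) otherwise.
    """
    z = cell_index(b, lows, highs, bins)
    if z not in archive:
        archive[z] = (theta, quality)
        return "new", quality
    old_quality = archive[z][1]
    if quality > old_quality:
        archive[z] = (theta, quality)
        return "improved", quality - old_quality
    return "rejected", 0.0


archive = {}
print(archive_insert(archive, [0.1], -5.0, [0.15], [0.0], [1.0], 10))
print(archive_insert(archive, [0.2], -3.0, [0.12], [0.0], [1.0], 10))
```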
CMA-ME Algorithm
CMA-ME maintains a set of emitters, which are instances of CMA-ES optimizing for improvement in the archive.
- Emission: An emitter samples a candidate \(\theta\) from its distribution \(\mathcal{N}(\mathbf{m}, \sigma^2 \mathbf{C})\).
- Evaluation: Calculate quality \(f(\theta)\) and behavior \(\mathbf{b}(\theta)\).
- Archive Update:
- Determine the cell index \(\mathbf{z} = \text{index}(\mathbf{b}(\theta))\).
- If cell \(\mathcal{A}_{\mathbf{z}}\) is empty or \(f(\theta) > f(\mathcal{A}_{\mathbf{z}})\), replace the occupant with \(\theta\).
- Calculate the “improvement” value \(\Delta\) (e.g., \(f(\theta) - f(\mathcal{A}_{\mathbf{z}}^{old})\)).
- Emitter Update: The CMA-ES emitter updates its mean and covariance based on the improvement \(\Delta\), guiding the search toward regions of the behavior space where quality can be improved or new cells can be discovered.
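One common way to turn \(\Delta\) into a CMA-ES update, used by the improvement emitter in the original CMA-ME paper, is to rank candidates that filled a new cell ahead of those that improved an occupied cell, sort each group by \(\Delta\), and then recombine the top \(\mu\) exactly as in the CMA-ES mean update. The sketch below assumes that scheme and the (status, delta) convention of a "best per cell" insert; jax_mppi's CMAMEOpt may rank or weight candidates differently.

```python
import math

PRIORITY = {"new": 0, "improved": 1, "rejected": 2}


def rank_for_emitter(results):
    """Order (theta, status, delta) tuples for an improvement emitter's update.

    New-cell candidates first, then improvements, then rejected ones;
    within each group, a larger delta ranks higher.
    """
    return sorted(results, key=lambda r: (PRIORITY[r[1]], -r[2]))


def recombine(ranked, mu):
    """Weighted mean of the top-mu ranked candidates (log-rank weights summing to 1)."""
    raw = [math.log(mu + 0.5) - math.log(i + 1) for i in range(mu)]
    total = sum(raw)
    dim = len(ranked[0][0])
    return [sum(raw[i] / total * ranked[i][0][d] for i in range(mu)) for d in range(dim)]


results = [
    ([0.0], "improved", 0.5),
    ([1.0], "rejected", 0.0),
    ([2.0], "new", -1.0),
    ([3.0], "improved", 2.0),
    ([4.0], "new", -0.2),
]
ranked = rank_for_emitter(results)
print([r[0][0] for r in ranked], recombine(ranked, mu=2))
```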
Global Optimization with Ray Tune
For global search over large, potentially non-convex spaces with complex constraints, we use Ray Tune. The problem is formulated as:
\[ \min_{\theta \in \Theta_{global}} \mathcal{J}(\theta) \]
where \(\Theta_{global}\) can be defined by complex distributions (e.g., Log-Uniform, Categorical).
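A log-uniform distribution, such as the tune.loguniform(0.1, 10.0) search space for \(\lambda\) in the Ray Tune example, draws \(\log \theta\) uniformly, so each decade is sampled equally often. The plain-Python sketch below illustrates the distribution itself; it is not Ray Tune's implementation, and sample_loguniform is a hypothetical helper.

```python
import math
import random


def sample_loguniform(low, high, rng):
    """Draw x such that log(x) is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)
samples = [sample_loguniform(0.1, 10.0, rng) for _ in range(1000)]
below_one = sum(x < 1.0 for x in samples) / len(samples)
print(f"fraction of samples in [0.1, 1): {below_one:.2f}")
```

Roughly half of the samples fall in [0.1, 1), whereas a plain uniform draw on [0.1, 10] would put only about 9% there. This is why log-uniform spaces suit scale parameters like \(\lambda\) whose order of magnitude is unknown.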
Ray Tune orchestrates the search using algorithms like:
- Bayesian Optimization: Uses a Gaussian Process surrogate model \(P(f \mid \mathcal{D})\) to approximate the objective and an acquisition function \(a(\theta)\) (e.g., Expected Improvement) to select the next sample:
\[ \theta_{next} = \arg\max_{\theta} a(\theta) \]
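- HyperOpt (TPE): Models \(p(\theta \mid y)\) with Tree-structured Parzen Estimators, fitting one density \(\ell(\theta)\) to the best-performing observations and another density \(g(\theta)\) to the rest, and proposes candidates that maximize the ratio \(\ell(\theta) / g(\theta)\).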
These methods are particularly useful for “warm-starting” the local search (CMA-ES) or finding the best family of parameters (e.g., finding the right order of magnitude for \(\lambda\)).
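The Expected Improvement acquisition mentioned for Bayesian Optimization has a closed form when the surrogate's prediction of the cost at \(\theta\) is Gaussian with mean \(\mu(\theta)\) and standard deviation \(s(\theta)\). The sketch below states it for cost minimization and picks \(\theta_{next}\) from a finite candidate set; the posterior callable is a hypothetical stand-in for a fitted GP, and real Bayesian Optimization libraries optimize the acquisition over a continuous space instead.

```python
import math


def expected_improvement(mean, std, best_cost):
    """EI for minimization under a Gaussian prediction N(mean, std^2) of the cost."""
    improvement = best_cost - mean
    if std <= 0.0:
        return max(improvement, 0.0)
    z = improvement / std
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))  # standard normal CDF
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)  # standard normal PDF
    return improvement * cdf + std * pdf


def select_next(candidates, posterior, best_cost):
    """theta_next = argmax over candidates of EI; posterior(theta) -> (mean, std)."""
    return max(candidates, key=lambda th: expected_improvement(*posterior(th), best_cost))


# Hypothetical surrogate predictions (mean cost, std) for three lambda values.
predictions = {0.1: (1.0, 0.01), 1.0: (0.8, 0.05), 10.0: (1.5, 0.5)}
print(select_next(list(predictions), predictions.get, best_cost=0.9))
```

The first EI term rewards candidates whose predicted cost already beats the best observation (exploitation); the second rewards predictive uncertainty (exploration).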