# Performance Analysis of JAX-MPPI Autotuning
This document outlines the performance bottlenecks and issues identified in the autotuning module of `jax_mppi`, specifically focusing on the `evosax` integration.
## 1. Architectural Bottleneck: Stateful vs. Functional
The primary reason `autotune_evosax.py` does not achieve the expected performance gains over the CPU-based `cma` backend is a fundamental mismatch between the `Autotune` framework's architecture and JAX's functional programming model.
- Current Architecture: The `Autotune` class and `TunableParameter` interface rely on a shared, mutable `ConfigStateHolder`. The `evaluate_fn` is a black box that depends on this side-effect-laden state-update mechanism.
- Impact: This prevents `vmap`-ing the evaluation function over a population of parameters. JAX requires pure functions to parallelize execution. Because `TunableParameter.apply_parameter_value` modifies the global holder in place, it cannot be used safely within a `jax.vmap` or `jax.lax.scan` context without significant refactoring (a sketch of the functional alternative follows below).
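For contrast, here is a minimal sketch of the functional shape JAX needs; `MPPIConfig`, `apply_parameter_value`, and `rollout_cost` are hypothetical stand-ins, not the real jax_mppi API:

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class MPPIConfig:
    # Illustrative fields; the real config has more.
    temperature: float
    noise_sigma: float


def apply_parameter_value(config: MPPIConfig, value: jax.Array) -> MPPIConfig:
    # Returns a *new* config instead of mutating a shared holder.
    return dataclasses.replace(config, temperature=value)


def rollout_cost(config: MPPIConfig) -> jax.Array:
    # Placeholder for the real MPPI rollout; what matters is purity.
    return (config.temperature - 0.7) ** 2 + config.noise_sigma


base = MPPIConfig(temperature=1.0, noise_sigma=0.1)


def evaluate(value: jax.Array) -> jax.Array:
    return rollout_cost(apply_parameter_value(base, value))


# Pure all the way down, so the population axis can simply be vmapped:
costs = jax.vmap(evaluate)(jnp.linspace(0.0, 2.0, 64))  # shape (64,)
```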
## 2. Sequential Evaluation in Evosax Optimization
In `src/jax_mppi/autotune_evosax.py`, the `optimize_step` method performs the following loop:
```python
# Evaluate all solutions sequentially
results = []
fitness_values = []
for x in solutions:
    result = self.evaluate_fn(np.array(x))  # type: ignore
    results.append(result)
    # ...
```

- Issue: The population generated by `evosax` (on the GPU) is iterated over in Python. Each candidate solution is converted to a NumPy array, transferred to the CPU, and evaluated individually.
- Consequence: This completely negates JAX's massive parallelization advantage. Instead of running `N` simulations in parallel on the GPU, they run sequentially (or with limited batching if `evaluate_fn` batches internally; typically `evaluate_fn` evaluates a single configuration).
- Comparison: While `cma` is CPU-based and expects sequential or parallel CPU evaluation, `evosax` is designed to run the entire ask-evaluate-tell loop on the GPU. The current implementation uses `evosax` only for the "ask" and "tell" steps, leaving the most expensive part, the evaluation, to a slow Python loop (see the batched sketch below).
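If the rollout can be exposed as a pure function of a single solution vector (`pure_evaluate` below is a hypothetical name, not an existing function in the module), the loop collapses into one `jax.vmap` call:

```python
import jax
import jax.numpy as jnp


def pure_evaluate(x: jax.Array) -> jax.Array:
    # Placeholder cost; in jax_mppi this would run one full MPPI rollout.
    return jnp.sum(x ** 2)


solutions = jnp.zeros((64, 4))  # population as returned by es.ask, on device

# One batched device launch replaces the per-candidate Python loop:
fitness_values = jax.vmap(pure_evaluate)(solutions)  # shape (64,)
```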
## 3. Data Transfer Overhead
The interface forces repeated data movement between device and host:
- `solutions` (from `es.ask`) are JAX arrays on the GPU.
- `np.array(x)` moves individual solution vectors to the CPU.
- `evaluate_fn` likely uses JAX internally, so it may move data back to the GPU for simulation.
- Results are moved back to the CPU.
- `fitness_array = jnp.array(fitness_values)` moves the costs back to the GPU for `es.tell` (the full round trip is annotated below).
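Schematically, the per-generation data path of the current loop looks like this, with each host/device crossing marked (an annotation of the existing code, not a runnable snippet):

```python
solutions = es.ask(...)                    # JAX array, lives on the GPU
fitness_values = []
for x in solutions:                        # Python iteration over a device array
    x_host = np.array(x)                   # GPU -> CPU copy, once per candidate
    fitness_values.append(
        self.evaluate_fn(x_host)           # may re-enter JAX: CPU -> GPU and back
    )
fitness_array = jnp.array(fitness_values)  # CPU -> GPU copy for es.tell
```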
## 4. Lack of End-to-End JIT Compilation
`evosax` allows the entire optimization process (multiple generations) to be JIT-compiled using `jax.lax.scan`.
- Current State: `optimize_step` is a Python method that cannot be JIT-compiled because it calls the Python-based `evaluate_fn` loop.
- Unused Code: `_create_jax_evaluate_fn` exists in `autotune_evosax.py` but is not used effectively to enable pure-JAX evaluation (an end-to-end sketch follows below).
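With a pure, vmapped evaluation in hand, the generation loop itself can be scanned and JIT-compiled. A sketch assuming the evosax 0.1-style `initialize`/`ask`/`tell` API (newer evosax releases renamed parts of this interface) and the hypothetical `pure_evaluate` from above:

```python
import jax
import jax.numpy as jnp
from evosax import CMA_ES


def pure_evaluate(x: jax.Array) -> jax.Array:
    return jnp.sum(x ** 2)  # placeholder cost


strategy = CMA_ES(popsize=64, num_dims=4)
es_params = strategy.default_params
NUM_GENERATIONS = 50  # static so lax.scan knows the trip count


def generation(carry, _):
    key, state = carry
    key, ask_key = jax.random.split(key)
    x, state = strategy.ask(ask_key, state, es_params)   # sample population
    fitness = jax.vmap(pure_evaluate)(x)                 # batched, on device
    state = strategy.tell(x, fitness, state, es_params)  # update search state
    return (key, state), jnp.min(fitness)


@jax.jit
def optimize(key):
    key, init_key = jax.random.split(key)
    state = strategy.initialize(init_key, es_params)
    (_, final_state), best_per_gen = jax.lax.scan(
        generation, (key, state), None, length=NUM_GENERATIONS
    )
    return final_state, best_per_gen


final_state, best_per_gen = optimize(jax.random.PRNGKey(42))
```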
## 5. Issues in Quality Diversity (QD) Tuning
The same sequential evaluation pattern is present in `src/jax_mppi/autotune_qd.py`:
```python
for solution in solutions:
    result = self.evaluate_fn(solution)
    results.append(result)
```

This limits the scalability of the QD algorithms (CMA-ME), which typically benefit from large population sizes (see the sketch below).
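The fix mirrors section 2, with one QD-specific wrinkle: each evaluation must also return the behavior measures the archive bins candidates by. A sketch with a hypothetical `evaluate_with_measures`:

```python
import jax
import jax.numpy as jnp


def evaluate_with_measures(x: jax.Array):
    # Hypothetical pure evaluation for QD: cost plus behavior descriptors.
    cost = jnp.sum(x ** 2)
    measures = x[:2]  # e.g. two behavior measures
    return cost, measures


solutions = jnp.zeros((128, 4))  # large QD population, on device

# vmap batches tuple outputs, so costs and measures come back together:
costs, measures = jax.vmap(evaluate_with_measures)(solutions)
# costs: (128,), measures: (128, 2) -- ready for a single archive tell().
```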
## 6. Minor Issues
- Hardcoded PRNG Key: `EvoSaxOptimizer.setup_optimization` resets the random key to `jax.random.PRNGKey(0)`. This forces deterministic behavior that resets on every setup call, which may not be desired if the user wants to continue optimization or run multiple independent trials (see the sketch after this list).
- Type Hinting: `Autotune.optimize_all` is typed to return `EvaluationResult`, but it can return `None` if `iterations=0`.
- Unused Variable: `self.jax_evaluate_fn` in `EvoSaxOptimizer` is assigned but never used.
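A sketch of the usual fix for the seeding issue: take the seed from the caller and split a fresh subkey per setup call (the class below is illustrative, not the real `EvoSaxOptimizer`):

```python
import jax


class SeededOptimizer:
    """Illustrative only: shows the seeding pattern, not the real class."""

    def __init__(self, seed: int = 0):
        # The caller controls determinism instead of a hardcoded PRNGKey(0).
        self.key = jax.random.PRNGKey(seed)

    def setup_optimization(self) -> jax.Array:
        # Split rather than reset: repeated setups and independent trials
        # get distinct but reproducible randomness.
        self.key, setup_key = jax.random.split(self.key)
        return setup_key
```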
## Recommendations for Improvement
- Refactor `TunableParameter`: Create a functional interface where parameters are applied to a config/state to produce a new config/state without side effects (as sketched in section 1).
- Vectorized Evaluation: Update `Autotune` to support a `batched_evaluate_fn` that accepts a batch of parameters (a JAX array) and returns a batch of costs; a signature sketch follows below.
- JIT-Compile the Loop: Once evaluation is vectorized and pure, use `jax.lax.scan` to run the optimization loop entirely on the GPU (as in the section 4 sketch).
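A possible contract for the second recommendation (the `batched_evaluate_fn` name comes from the text above; the type alias is illustrative):

```python
from typing import Callable

import jax

# A batch of parameter vectors in, a batch of scalar costs out,
# both staying on device: (pop, dim) -> (pop,)
BatchedEvaluateFn = Callable[[jax.Array], jax.Array]
```

With this contract in place, `optimize_step` becomes a pure function of the search state and a PRNG key, and the `jax.lax.scan` loop from the section 4 sketch follows directly.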