Performance Analysis of JAX-MPPI Autotuning

This document outlines the performance bottlenecks and issues identified in the autotuning module of jax_mppi, specifically focusing on the evosax integration.

1. Architectural Bottleneck: Stateful vs Functional

The primary reason autotune_evosax.py does not deliver the expected speedup over the CPU-based cma backend is a fundamental mismatch between the architecture of the Autotune framework and JAX’s functional programming model.

  • Current Architecture: The Autotune class and the TunableParameter interface rely on a shared, mutable ConfigStateHolder. evaluate_fn is a black-box callable that depends on this side-effecting state update.
  • Impact: This prevents vmap-ing the evaluation function over a population of parameters. JAX transformations such as jax.jit, jax.vmap, and jax.lax.scan trace a function and assume it is pure: Python side effects run only while tracing, not once per candidate or per iteration. Because TunableParameter.apply_parameter_value modifies the shared holder in place, it cannot safely be used inside jax.vmap or jax.lax.scan without significant refactoring.
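The sketch below illustrates the problem with a toy cost. ConfigHolder, stateful_cost, and pure_cost are hypothetical stand-ins, not the jax_mppi classes. jax.vmap runs the Python body once for the whole batch, so a "write the candidate into shared state" step does not happen once per candidate, and the holder ends up storing a tracer rather than a concrete value.

```python
import jax
import jax.numpy as jnp


class ConfigHolder:
    """Hypothetical stand-in for a mutable ConfigStateHolder."""

    def __init__(self) -> None:
        self.writes = 0
        self.temperature = 1.0


holder = ConfigHolder()


def stateful_cost(x):
    # Side-effecting style: write the candidate into shared state, then read it back.
    holder.writes += 1          # Python side effect, executed only at trace time
    holder.temperature = x      # under vmap, `x` is a batch tracer, not a number
    return (holder.temperature - 2.0) ** 2


def pure_cost(temperature):
    # Functional style: the parameter is an argument, the cost is the return value.
    return (temperature - 2.0) ** 2


population = jnp.array([0.0, 1.0, 2.0, 3.0])

batched_pure = jax.vmap(pure_cost)(population)
batched_stateful = jax.vmap(stateful_cost)(population)

print(batched_pure)    # [4. 1. 0. 1.]
# The body ran once for all four candidates, and holder.temperature now holds a
# leaked tracer instead of a per-candidate value.
print(holder.writes)   # 1
```

The numbers happen to agree here because the toy cost reads the holder immediately, but any bookkeeping that assumes one write per candidate is wrong, and later reads of the holder see a leaked tracer.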

2. Sequential Evaluation in Evosax Optimization

In src/jax_mppi/autotune_evosax.py, the optimize_step method performs the following loop:

```python
# Evaluate all solutions sequentially
results = []
fitness_values = []

for x in solutions:
    # np.array(x) copies one candidate from device memory to the host
    result = self.evaluate_fn(np.array(x))  # type: ignore
    results.append(result)
    # ...
```

  • Issue: The population returned by evosax’s “ask” step lives on the GPU but is iterated over in Python. Each candidate is converted with np.array, which copies it to host memory, and is then evaluated on its own.
  • Consequence: This forfeits JAX’s main advantage. Instead of evaluating N candidates in parallel on the GPU, the loop evaluates them one after another and pays Python dispatch overhead for each. Some parallelism is recovered only if evaluate_fn batches internally, but it typically evaluates a single configuration.
  • Comparison: cma runs on the CPU and expects candidates to be evaluated there, either sequentially or with CPU parallelism. evosax, by contrast, is designed to run the entire ask-evaluate-tell loop on the GPU. The current implementation uses evosax only for the “ask” and “tell” steps and leaves the most expensive step, evaluation, to a slow Python loop.
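A minimal sketch of the vectorized alternative, assuming the cost of one candidate can be written as a pure JAX function. rollout_cost is a hypothetical toy stand-in for an MPPI rollout, and the random population stands in for the output of es.ask. The whole population is evaluated in one compiled call, and the sequential pattern is kept only as a reference to show the results match.

```python
import jax
import jax.numpy as jnp
import numpy as np


def rollout_cost(params):
    """Hypothetical pure cost: toy stand-in for an MPPI rollout evaluation."""
    # params = [temperature, noise_sigma]; quadratic bowl with minimum at (0.5, 1.0)
    target = jnp.array([0.5, 1.0])
    return jnp.sum((params - target) ** 2)


# One compiled call evaluates every candidate on the device.
batched_cost = jax.jit(jax.vmap(rollout_cost))

key = jax.random.PRNGKey(0)
solutions = jax.random.normal(key, (64, 2))  # stands in for the population from es.ask

fitness = batched_cost(solutions)            # shape (64,), stays on the device

# Reference: the sequential pattern from optimize_step (host round-trip per candidate).
fitness_loop = np.array(
    [float(rollout_cost(jnp.asarray(np.array(x)))) for x in solutions]
)
print(fitness.shape)  # (64,)
```

The design choice is to make the per-candidate function pure and let jax.vmap add the batch axis, rather than writing batched logic by hand.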

3. Data Transfer Overhead

The interface forces repeated data movement between device and host:

  1. solutions (from es.ask) are JAX arrays on GPU.
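  2. np.array(x) copies each solution vector to host memory. This is a blocking transfer: it waits for any pending device computation to finish.
  3. evaluate_fn likely uses JAX internally, in which case the vector is copied back to the GPU for simulation.
  4. The resulting cost is copied back to the host.
  5. fitness_array = jnp.array(fitness_values) copies the collected costs to the GPU again for es.tell.

Steps 2 through 4 repeat for every candidate, so a population of N incurs on the order of N round-trips per generation.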
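One way to avoid these round-trips is to compile a full generation (sample, evaluate, update) into one function so that the population and its fitness never leave the device. The sketch below is a toy (mu, lambda) evolution strategy written for illustration; rollout_cost and generation_step are hypothetical, and the sampling and mean update stand in for es.ask and es.tell rather than reproducing the evosax API.

```python
import jax
import jax.numpy as jnp


def rollout_cost(params):
    """Hypothetical pure cost standing in for an MPPI rollout evaluation."""
    return jnp.sum((params - jnp.array([0.5, 1.0])) ** 2)


@jax.jit
def generation_step(mean, sigma, key):
    """One ask-evaluate-tell step of a toy (mu, lambda)-ES, entirely on device."""
    popsize, n_elite = 32, 8
    # "ask": sample a population around the current mean
    noise = jax.random.normal(key, (popsize, mean.shape[0]))
    solutions = mean + sigma * noise
    # "evaluate": batched, no host round-trip
    fitness = jax.vmap(rollout_cost)(solutions)
    # "tell": move the mean to the average of the n_elite lowest-cost candidates
    elite_idx = jnp.argsort(fitness)[:n_elite]
    new_mean = solutions[elite_idx].mean(axis=0)
    return new_mean, fitness.min()


mean = jnp.zeros(2)
key = jax.random.PRNGKey(0)
for _ in range(20):
    key, subkey = jax.random.split(key)
    mean, best = generation_step(mean, 0.3, subkey)  # results remain device arrays
print(float(best))  # the only device-to-host transfer, used for logging
```

The Python loop still dispatches once per generation, but each dispatch moves no per-candidate data across the host boundary.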

4. Lack of End-to-End JIT Compilation

Because evosax strategies are written as pure JAX functions, the entire optimization process (multiple generations) can be JIT-compiled with jax.lax.scan.

  • Current State: optimize_step is a Python method that cannot be JIT-compiled because it loops over candidates in Python, converts them to NumPy arrays, and calls evaluate_fn on each.
  • Unused Code: _create_jax_evaluate_fn exists in autotune_evosax.py but is not used to provide a JAX-pure evaluation path.
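Once evaluation is pure, the generation loop itself can move into jax.lax.scan so that a full run is a single compiled call. This is a minimal sketch under the same assumptions as a toy ES: make_optimizer, step, and rollout_cost are hypothetical names, and the update rule is a simple truncation-selection mean update, not evosax’s.

```python
import jax
import jax.numpy as jnp

TARGET = jnp.array([0.5, 1.0])


def rollout_cost(params):
    """Hypothetical pure cost standing in for an MPPI rollout evaluation."""
    return jnp.sum((params - TARGET) ** 2)


def make_optimizer(num_generations, popsize=32, n_elite=8, sigma=0.3):
    def step(mean, key):
        # One generation: sample, evaluate in parallel, update the mean.
        noise = jax.random.normal(key, (popsize, mean.shape[0]))
        solutions = mean + sigma * noise
        fitness = jax.vmap(rollout_cost)(solutions)
        elite = solutions[jnp.argsort(fitness)[:n_elite]]
        return elite.mean(axis=0), fitness.min()  # (new carry, per-generation output)

    @jax.jit
    def run(init_mean, key):
        keys = jax.random.split(key, num_generations)
        # scan threads the mean through all generations inside one compiled program
        return jax.lax.scan(step, init_mean, keys)

    return run


run = make_optimizer(num_generations=50)
final_mean, best_per_gen = run(jnp.zeros(2), jax.random.PRNGKey(0))
print(best_per_gen.shape)  # (50,)
```

The per-generation best costs come back as one stacked array, so logging needs a single transfer at the end of the run.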

5. Issues in Quality Diversity (QD) Tuning

The same sequential evaluation pattern is present in src/jax_mppi/autotune_qd.py:

```python
for solution in solutions:
    result = self.evaluate_fn(solution)
    results.append(result)
```

This limits the scalability of the QD algorithm (CMA-ME), which typically benefits from large populations and many archive insertions per generation.
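For QD, a pure evaluation function that returns both a cost and a behavior descriptor can be vmapped, and archive insertion can be vectorized too. The sketch below assumes a hypothetical 1-D grid archive that stores only the best cost per cell; evaluate, GRID, and evaluate_and_insert are illustrative names, not the autotune_qd.py API, and a full MAP-Elites archive would also need to store the elite solutions.

```python
import jax
import jax.numpy as jnp

GRID = 10  # hypothetical 1-D archive: 10 cells over the descriptor range [0, 1)


def evaluate(params):
    """Hypothetical pure QD evaluation: returns (cost, behavior descriptor)."""
    cost = jnp.sum(params ** 2)
    descriptor = jax.nn.sigmoid(params[0])  # scalar in (0, 1)
    return cost, descriptor


@jax.jit
def evaluate_and_insert(archive_cost, solutions):
    # vmap over a function returning a tuple yields a tuple of batched arrays.
    costs, descs = jax.vmap(evaluate)(solutions)
    cells = jnp.clip((descs * GRID).astype(jnp.int32), 0, GRID - 1)
    # Scatter-min keeps the lowest cost per cell; repeated cell indices are combined.
    return archive_cost.at[cells].min(costs)


archive = jnp.full((GRID,), jnp.inf)
solutions = jax.random.normal(jax.random.PRNGKey(1), (256, 3))
archive = evaluate_and_insert(archive, solutions)
print(jnp.isfinite(archive).sum())  # number of filled cells
```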

6. Minor Issues

  • Hardcoded PRNG Key: EvoSaxOptimizer.setup_optimization resets the random key to jax.random.PRNGKey(0) on every call. The random stream therefore restarts with each setup, which is undesirable when continuing an optimization or running multiple independent trials.
  • Type Hinting: Autotune.optimize_all is annotated as returning EvaluationResult, but it returns None when iterations=0. The annotation should be Optional[EvaluationResult], or the method should reject iterations < 1.
  • Unused Variable: self.jax_evaluate_fn in EvoSaxOptimizer is assigned but never used.
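A minimal sketch of fixes for the first two items. SketchOptimizer, EvalResult, and this optimize_all signature are hypothetical stand-ins, not the jax_mppi classes: the seed is taken once at construction, setup only replaces the key when the caller passes one, and the return type admits the iterations=0 case.

```python
from dataclasses import dataclass
from typing import Callable, Optional

import jax


@dataclass
class EvalResult:
    """Hypothetical stand-in for EvaluationResult."""
    mean_cost: float


class SketchOptimizer:
    """Hypothetical sketch; not the jax_mppi EvoSaxOptimizer API."""

    def __init__(self, seed: int = 0) -> None:
        self.key = jax.random.PRNGKey(seed)

    def setup_optimization(self, key: Optional[jax.Array] = None) -> None:
        # Replace the key only when the caller supplies one, so repeated setup
        # calls continue the existing random stream instead of restarting it.
        if key is not None:
            self.key = key

    def next_key(self) -> jax.Array:
        self.key, subkey = jax.random.split(self.key)
        return subkey


def optimize_all(step: Callable[[], EvalResult], iterations: int) -> Optional[EvalResult]:
    """Hypothetical shape of optimize_all with a return type that admits None."""
    best: Optional[EvalResult] = None
    for _ in range(iterations):
        result = step()
        if best is None or result.mean_cost < best.mean_cost:
            best = result
    return best  # None when iterations == 0
```

Independent trials then differ only by the seed passed to the constructor, which keeps them reproducible without forcing them to share one stream.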

Recommendations for Improvement

  1. Refactor TunableParameter: Create a functional interface where parameters can be applied to a config/state to produce a new config/state without side effects.
  2. Vectorized Evaluation: Update Autotune to support a batched_evaluate_fn that accepts a batch of parameters (a JAX array of shape (popsize, n_params)) and returns a batch of costs of shape (popsize,).
  3. JIT-compile Loop: Once evaluation is vectorized and pure, use jax.lax.scan to run the optimization loop entirely on the GPU.
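A sketch of recommendations 1 and 2 together, under the assumption that the tunable part of the configuration can be held in an immutable NamedTuple (NamedTuples are JAX pytrees). MPPIConfig, FunctionalParam, cost_from_config, and make_batched_evaluate_fn are hypothetical names, and cost_from_config is a toy in place of an actual MPPI run. Each parameter returns a new config instead of mutating a holder, which is what makes jax.vmap applicable.

```python
from typing import Callable, NamedTuple, Sequence

import jax
import jax.numpy as jnp


class MPPIConfig(NamedTuple):
    """Hypothetical immutable config; NamedTuples are JAX pytrees."""
    lambda_: jax.Array
    noise_sigma: jax.Array
    horizon: int = 20  # static, not tuned


class FunctionalParam(NamedTuple):
    """Hypothetical functional replacement for TunableParameter."""
    name: str
    apply: Callable[[MPPIConfig, jax.Array], MPPIConfig]  # returns a new config


PARAMS: Sequence[FunctionalParam] = (
    FunctionalParam("lambda", lambda cfg, v: cfg._replace(lambda_=v)),
    FunctionalParam("noise_sigma", lambda cfg, v: cfg._replace(noise_sigma=v)),
)


def apply_all(base: MPPIConfig, x: jax.Array) -> MPPIConfig:
    # Fold each parameter value into a fresh config; `base` is never mutated.
    cfg = base
    for i, p in enumerate(PARAMS):
        cfg = p.apply(cfg, x[i])
    return cfg


def cost_from_config(cfg: MPPIConfig) -> jax.Array:
    # Toy stand-in for running MPPI with `cfg` and returning its mean cost.
    return (cfg.lambda_ - 1.0) ** 2 + (cfg.noise_sigma - 0.3) ** 2


def make_batched_evaluate_fn(base: MPPIConfig):
    def evaluate_one(x):
        return cost_from_config(apply_all(base, x))

    return jax.jit(jax.vmap(evaluate_one))  # (popsize, n_params) -> (popsize,)


base = MPPIConfig(lambda_=jnp.array(0.5), noise_sigma=jnp.array(1.0))
batched_evaluate_fn = make_batched_evaluate_fn(base)
costs = batched_evaluate_fn(jnp.array([[1.0, 0.3], [0.0, 0.0]]))
```

With evaluation expressed this way, batched_evaluate_fn can be called directly inside a jax.lax.scan step, which is the precondition for recommendation 3.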