Testing Guide
This guide explains the testing stack for jax_mppi and provides instructions on how to run and write tests.
Running Tests
The project uses pytest for running tests. You can run all tests using uv:
```bash
uv run pytest
```
To run a specific test file:
```bash
uv run pytest tests/test_mppi.py
```
To run a specific test case:
```bash
uv run pytest tests/test_mppi.py::TestMPPICommand::test_command_returns_correct_shapes
```
Test Suite Structure
The tests are located in the tests/ directory and mirror the source code structure where appropriate. The test suite is divided into several files, each covering a specific flavor or aspect of the library.
Core MPPI Flavors
`tests/test_mppi.py`: Tests for the base MPPI implementation (`jax_mppi.mppi`).
- Goal: Ensure the correctness of the core algorithm, state management, and configuration options.
- Scope:
  - Initialization: Verifies that `create()` returns correct shapes and types for `config` and `state`.
  - Command Generation: Tests the `command()` function to ensure it generates valid actions within bounds and correctly updates the state.
  - Configuration Options: Validates settings such as `u_per_command` (multi-step control), `step_dependent_dynamics` (time-varying systems), `sample_null_action` (ensuring baseline inclusion), and `u_scale` (control authority scaling).
  - Integration: Includes basic convergence tests that verify the cost decreases over iterations (e.g., `TestMPPIIntegration`).
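To illustrate the kind of convergence check described above, here is a self-contained sketch that does not depend on `jax_mppi`. It assumes a hypothetical 1-D integrator (`x[t+1] = x[t] + u[t] * dt`) with a quadratic cost, implements a bare-bones MPPI update in NumPy (sample perturbations, weight them by `exp(-cost / lambda)`, average the samples), and asserts that the final cost is lower than the initial one. The helpers `rollout_cost`, `mppi_iteration`, and `run_convergence` are illustrative only and are not the library's API.

```python
import numpy as np

DT = 0.1  # hypothetical time step of the toy system


def rollout_cost(u_seq, x0=1.0):
    """Quadratic cost of a hypothetical 1-D integrator driven by u_seq."""
    x, cost = x0, 0.0
    for u in u_seq:
        x = x + u * DT
        cost += x**2 + 0.01 * u**2
    return cost


def mppi_iteration(u_nom, rng, num_samples=256, sigma=1.0, lam=1.0, u_min=-2.0, u_max=2.0):
    """One MPPI update: perturb, clip, weight by exp(-cost / lambda), average."""
    noise = rng.normal(0.0, sigma, size=(num_samples, u_nom.shape[0]))
    noise[0] = 0.0  # keep the current nominal sequence as one of the samples
    samples = np.clip(u_nom + noise, u_min, u_max)
    costs = np.array([rollout_cost(s) for s in samples])
    # Subtract the minimum cost before exponentiating for numerical stability.
    weights = np.exp(-(costs - costs.min()) / lam)
    weights /= weights.sum()
    # A convex combination of clipped samples stays within [u_min, u_max].
    return weights @ samples


def run_convergence(iterations=15, horizon=20, seed=0):
    rng = np.random.default_rng(seed)  # fixed seed keeps the test deterministic
    u_nom = np.zeros(horizon)
    costs = [rollout_cost(u_nom)]
    for _ in range(iterations):
        u_nom = mppi_iteration(u_nom, rng)
        costs.append(rollout_cost(u_nom))
    return u_nom, costs


def test_cost_decreases():
    u_final, costs = run_convergence()
    assert costs[-1] < costs[0]
    assert np.all(np.abs(u_final) <= 2.0 + 1e-9)
```

Asserting only that the final cost is below the initial cost, rather than a strict decrease at every iteration, keeps such a test robust to sampling noise.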
`tests/test_smppi.py`: Tests for Smooth MPPI (`jax_mppi.smppi`).
- Goal: Verify that the “smooth” variant correctly operates in the lifted velocity control space and produces continuous action sequences.
- Scope:
  - Lifted Space: Checks that the internal state (`U`) represents control velocity/acceleration, while `action_sequence` represents the integrated actions.
  - Smoothness: Verifies that the smoothness cost penalty (`w_action_seq_cost`) effectively reduces action variance.
  - Bounds: Tests that bounds are respected for both the control velocity (`u_min`/`u_max`) and the final action (`action_min`/`action_max`).
  - Continuity: Checks that the `shift` operation maintains continuity in the action space, preventing jumps during receding-horizon updates.
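The lifted-space relationship and the continuity property can be sketched as follows. This is a minimal NumPy illustration under stated assumptions, not the `jax_mppi.smppi` implementation: it assumes the action sequence is obtained by clipping the velocities `U` to `[u_min, u_max]` and integrating them from the last applied action, and that the shift step drops the executed velocity and appends zero velocity. Action bounds (`action_min`/`action_max`) are omitted for brevity; `integrate_actions`, `shift`, and `DT` are hypothetical names.

```python
import numpy as np

DT = 0.05  # hypothetical control period


def integrate_actions(a_prev, u_vel, u_min=-1.0, u_max=1.0):
    """Map lifted-space velocities (horizon, nu) to actions (hypothetical sketch).

    Velocities are clipped to [u_min, u_max] and integrated from the last applied
    action, so consecutive actions differ by at most DT * max(|u_min|, |u_max|).
    """
    u_vel = np.clip(u_vel, u_min, u_max)
    return a_prev + DT * np.cumsum(u_vel, axis=0)


def shift(u_vel):
    """Drop the executed velocity and append zero velocity (hold the last action)."""
    return np.concatenate([u_vel[1:], np.zeros_like(u_vel[:1])], axis=0)


def test_shift_preserves_continuity():
    rng = np.random.default_rng(0)
    u_vel = rng.uniform(-1.0, 1.0, size=(10, 2))  # horizon 10, nu = 2
    actions = integrate_actions(np.zeros(2), u_vel)

    # Execute the first action, then shift the plan by one step.
    new_actions = integrate_actions(actions[0], shift(u_vel))

    # The shifted plan continues the old one instead of jumping.
    np.testing.assert_allclose(new_actions[:-1], actions[1:], atol=1e-12)
    np.testing.assert_allclose(new_actions[-1], actions[-1], atol=1e-12)
    # Smoothness: no step changes the action by more than DT * u_max.
    assert np.all(np.abs(np.diff(actions, axis=0)) <= DT * 1.0 + 1e-12)
```

Because the plan is stored as velocities, a one-step shift reuses the remaining velocities unchanged, which is why the new action sequence lines up with the tail of the old one.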
`tests/test_kmppi.py`: Tests for Kernel MPPI (`jax_mppi.kmppi`).
- Goal: Ensure that kernel-based interpolation works correctly and that optimization occurs effectively in the reduced control-point space.
- Scope:
  - Kernels: Tests the properties of time-domain kernels (e.g., `RBFKernel`), such as output shape and decay with distance.
  - Interpolation: Verifies that control points (`theta`) are correctly mapped to full trajectories (`U`) via `_kernel_interpolate`, preserving values at the control points.
  - Optimization: Checks that the MPPI update rule is applied to the control points (`theta`) rather than the full trajectory.
  - Smoothness: Confirms that the resulting trajectories are smooth due to the kernel properties (e.g., by checking second derivatives).
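The interpolation and kernel-decay properties above can be checked with a small standalone sketch. It assumes the common kernel-interpolation form `U = K(t, t_c) K(t_c, t_c)^{-1} theta` with an RBF kernel; whether `jax_mppi.kmppi._kernel_interpolate` uses exactly this form is an assumption, and `rbf_kernel` and `kernel_interpolate` are hypothetical NumPy helpers.

```python
import numpy as np


def rbf_kernel(t1, t2, length_scale=1.0):
    """RBF kernel matrix k(t, t') = exp(-(t - t')^2 / (2 * length_scale^2))."""
    d = t1[:, None] - t2[None, :]
    return np.exp(-(d**2) / (2.0 * length_scale**2))


def kernel_interpolate(t_query, t_ctrl, theta, length_scale=1.0):
    """Map control points theta (num_ctrl, nu) to a trajectory (len(t_query), nu)."""
    k_cc = rbf_kernel(t_ctrl, t_ctrl, length_scale)
    k_qc = rbf_kernel(t_query, t_ctrl, length_scale)
    # Solve the linear system instead of inverting K_cc explicitly.
    return k_qc @ np.linalg.solve(k_cc, theta)


def test_interpolation_preserves_control_points():
    horizon, num_ctrl, nu = 20, 5, 2
    t_query = np.arange(horizon, dtype=float)
    t_ctrl = np.linspace(0.0, horizon - 1, num_ctrl)
    theta = np.random.default_rng(0).normal(size=(num_ctrl, nu))

    u = kernel_interpolate(t_query, t_ctrl, theta, length_scale=3.0)
    assert u.shape == (horizon, nu)

    # Interpolating at the control-point times recovers theta.
    u_at_ctrl = kernel_interpolate(t_ctrl, t_ctrl, theta, length_scale=3.0)
    np.testing.assert_allclose(u_at_ctrl, theta, atol=1e-6)


def test_rbf_kernel_decays_with_distance():
    values = rbf_kernel(np.array([0.0]), np.array([0.0, 1.0, 2.0, 4.0]))[0]
    assert values[0] == 1.0
    assert np.all(np.diff(values) < 0)
```

Only the `num_ctrl` rows of `theta` are free parameters here, which is the sense in which the optimization happens in a reduced space.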
Integration & Examples
`tests/test_pendulum.py`: End-to-end integration tests using a Pendulum environment.
- Goal: Validate that the algorithms can solve a concrete, nonlinear control task.
- Scope:
  - Stabilization: Tests whether MPPI can stabilize the pendulum at the upright position.
  - Swing-up: Tests the more difficult task of swinging up from a hanging position.
  - Physics: Sanity-checks the pendulum dynamics and cost functions.
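As an example of the physics sanity checks, the sketch below uses a hypothetical explicit-Euler pendulum in a Gym-style convention (`theta = 0` upright, `theta = pi` hanging) with a quadratic angle/velocity/torque cost. The parameters, function names, and cost weights are assumptions, not the ones used in `tests/test_pendulum.py`.

```python
import numpy as np

# Hypothetical parameters; theta = 0 is upright, theta = pi hangs down.
G, M, L, DT = 10.0, 1.0, 1.0, 0.05


def angle_normalize(theta):
    """Wrap an angle into [-pi, pi)."""
    return ((theta + np.pi) % (2.0 * np.pi)) - np.pi


def pendulum_step(state, u):
    """One explicit-Euler step of a torque-driven pendulum."""
    theta, theta_dot = state
    theta_dot = theta_dot + (3.0 * G / (2.0 * L) * np.sin(theta) + 3.0 / (M * L**2) * u) * DT
    theta = theta + theta_dot * DT
    return np.array([theta, theta_dot])


def running_cost(state, u):
    theta, theta_dot = state
    return angle_normalize(theta) ** 2 + 0.1 * theta_dot**2 + 0.001 * u**2


def test_upright_at_rest_is_costless_equilibrium():
    s = np.array([0.0, 0.0])
    assert running_cost(s, 0.0) == 0.0
    np.testing.assert_allclose(pendulum_step(s, 0.0), s)


def test_hanging_at_rest_stays_down():
    s = np.array([np.pi, 0.0])
    np.testing.assert_allclose(pendulum_step(s, 0.0), s, atol=1e-12)


def test_upright_is_unstable_without_torque():
    s = np.array([0.1, 0.0])
    for _ in range(10):
        s = pendulum_step(s, 0.0)
    assert abs(s[0]) > 0.1  # a small tilt grows when no torque is applied
```

Checks like these catch sign errors in the dynamics (for example, an upright position that is accidentally stable) before they show up as a failing stabilization or swing-up test.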
Autotuning
`tests/test_autotune.py`: Unit tests for the autotuning framework (`jax_mppi.autotune`).
- Goal: Verify the components of the hyperparameter optimization system.

`tests/test_autotune_integration.py`: Integration tests for autotuning.
- Goal: Ensure that the autotuner can successfully improve performance on a benchmark task, finding better parameters than the defaults.
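An integration test of this kind compares the tuned result against the default. The sketch below substitutes a plain log-uniform random search for the real tuner and a hypothetical benchmark whose cost is minimized at `lambda_ = 0.1`; neither `evaluate` nor `random_search` reflects the `jax_mppi.autotune` API.

```python
import math
import random


def evaluate(lambda_):
    """Hypothetical benchmark cost, minimized at lambda_ = 0.1."""
    return (math.log10(lambda_) - math.log10(0.1)) ** 2


def random_search(objective, low, high, num_trials, seed=0):
    """Log-uniform random search; stands in for a real tuner in this sketch."""
    rng = random.Random(seed)
    best_x, best_cost = None, float("inf")
    for _ in range(num_trials):
        x = 10 ** rng.uniform(math.log10(low), math.log10(high))
        cost = objective(x)
        if cost < best_cost:
            best_x, best_cost = x, cost
    return best_x, best_cost


def test_tuner_beats_default():
    default_cost = evaluate(1.0)  # hypothetical default temperature
    best_lambda, best_cost = random_search(evaluate, 1e-3, 1e1, num_trials=50)
    assert best_lambda is not None
    assert best_cost < default_cost
```

Seeding the search keeps the outcome reproducible, and the strict `<` makes the test fail if the tuner merely returns the default.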
Writing New Tests
When adding new features or fixing bugs, please add corresponding tests.
- Locate the appropriate test file: If you are modifying `mppi.py`, add tests to `tests/test_mppi.py`.
- Use Class-Based Structure: Group related tests into classes (e.g., `TestMPPIBasics`, `TestMPPICommand`).
- Property-Based Testing: Where possible, test properties that must hold for any valid input (e.g., `action.shape == (nu,)` for every `nu`) rather than just hardcoded values.
- Integration Tests: For significant algorithmic changes, ensure that `tests/test_pendulum.py` still passes, or add a similar simple control task to verify efficacy.
- JAX Compatibility: If a function is intended to be used within `jax.jit`, ensure the tests check that it can be JIT-compiled.
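A property-based test can be written without extra dependencies by looping over randomly generated inputs (a library such as Hypothesis automates the input generation). The sketch below checks properties of MPPI-style importance weights; `mppi_weights` is a hypothetical helper written for this example, not a `jax_mppi` function.

```python
import numpy as np


def mppi_weights(costs, temperature):
    """Softmax of -costs / temperature (hypothetical helper, not the jax_mppi API)."""
    shifted = -(costs - costs.min()) / temperature  # max exponent is 0, so no overflow
    w = np.exp(shifted)
    return w / w.sum()


def test_weight_properties_hold_for_random_inputs():
    rng = np.random.default_rng(0)
    for _ in range(50):
        num_samples = int(rng.integers(1, 500))
        costs = rng.normal(scale=100.0, size=num_samples)
        temperature = float(rng.uniform(0.01, 10.0))
        w = mppi_weights(costs, temperature)

        assert w.shape == costs.shape  # output shape follows the input shape
        assert np.all(w >= 0.0)
        np.testing.assert_allclose(w.sum(), 1.0, rtol=1e-12)
        assert np.argmax(w) == np.argmin(costs)  # lowest cost gets the largest weight
```

Each assertion holds for any valid input, so the test does not rely on hardcoded expected values.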
Example Test Case
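```python
import jax.numpy as jnp

from jax_mppi import mppi


class TestMPPINewFeature:  # hypothetical class name; group with related tests
    def test_new_feature(self):
        nx, nu = 2, 1
        config, state = mppi.create(nx=nx, nu=nu, noise_sigma=jnp.eye(nu))
        # ... perform action ...
        action, new_state = mppi.command(config, state, ...)
        # ... assert expected behavior ...
        assert action.shape == (nu,)
```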