Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions crazyflow/control/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,16 @@ def parametrize(
device: The device to use. If None, the device is inferred from the xp module.

Example:
```python
import numpy as np
from crazyflow.control import parametrize
from crazyflow.control.mellinger import state2attitude

ctrl = parametrize(state2attitude, "cf2x_L250")
pos = np.zeros(3)
quat = np.array([0.0, 0.0, 0.0, 1.0])
vel = np.zeros(3)
cmd = np.zeros(13)
rpyt, int_pos_err = ctrl(pos, quat, vel, cmd)
```
```python
import numpy as np
from crazyflow.control import parametrize
from crazyflow.control.mellinger import state2attitude

ctrl = parametrize(state2attitude, "cf2x_L250")
pos, quat = np.zeros(3), np.array([0.0, 0.0, 0.0, 1.0])
vel, cmd = np.zeros(3), np.zeros(13)
rpyt, int_pos_err = ctrl(pos, quat, vel, cmd)
```

Returns:
The parametrized controller function with all keyword argument only parameters filled in.
Expand Down
10 changes: 5 additions & 5 deletions crazyflow/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ def dynamics_features(dynamics: Callable) -> dict[str, bool]:
``ValueError``.

Example:
```python
from crazyflow.dynamics import dynamics_features
from crazyflow.dynamics.first_principles import dynamics
```python
from crazyflow.dynamics import dynamics_features
from crazyflow.dynamics.first_principles import dynamics

dynamics_features(dynamics) # {'rotor_dynamics': True}
```
dynamics_features(dynamics) # {'rotor_dynamics': True}
```
"""
if hasattr(dynamics, "func"): # Is a partial function
return dynamics_features(dynamics.func)
Expand Down
22 changes: 13 additions & 9 deletions crazyflow/dynamics/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,19 @@ def parametrize(
device: The device to use. If none, the device is inferred from the xp module.

Example:
```{ .python notest }
from crazyflow.dynamics.core import parametrize
from crazyflow.dynamics.first_principles import dynamics

dynamics_fn = parametrize(dynamics, drone="cf2x_L250")
pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = dynamics_fn(
pos=pos, quat=quat, vel=vel, ang_vel=ang_vel, cmd=cmd, rotor_vel=rotor_vel
)
```
```python
import numpy as np
from crazyflow.dynamics.core import parametrize
from crazyflow.dynamics.first_principles import dynamics

dynamics_fn = parametrize(dynamics, drone="cf2x_L250")
pos, quat = np.zeros(3), np.array([0.0, 0.0, 0.0, 1.0])
vel, ang_vel = np.zeros(3), np.zeros(3)
rotor_vel, cmd = np.zeros(4), np.zeros(4)
pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = dynamics_fn(
pos=pos, quat=quat, vel=vel, ang_vel=ang_vel, cmd=cmd, rotor_vel=rotor_vel
)
```

Returns:
The parametrized dynamics function with all keyword argument only parameters filled in.
Expand Down
10 changes: 10 additions & 0 deletions docs/examples/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ These runnable examples cover control, JAX transformations, pipeline extensions,

A single drone commanded to hold a fixed height using state control. This is the minimal end-to-end loop: create a `Sim`, reset it, apply a state command, and step forward.

<!-- notest: imported script, covered by tests/integration/test_examples.py -->
```{ .python notest }
--8<-- "examples/control/hover.py"
```
Expand All @@ -22,6 +23,7 @@ python examples/control/hover.py

Commanding roll, pitch, yaw, and collective thrust directly. This level bypasses the Mellinger position loop and is typical for RL agents that output attitude targets.

<!-- notest: imported script, covered by tests/integration/test_examples.py -->
```{ .python notest }
--8<-- "examples/control/attitude.py"
```
Expand All @@ -42,6 +44,7 @@ python examples/control/sampling.py

Because the simulator is built entirely from JAX operations, `jax.grad` can differentiate through it. Starting the drone above the target height keeps it away from the floor, so the floor-clipping stage never fires and gradients flow freely through the entire trajectory.

<!-- notest: imported script, covered by tests/integration/test_examples.py -->
```{ .python notest }
--8<-- "examples/jax/gradient.py"
```
Expand All @@ -52,6 +55,7 @@ Because the simulator is built entirely from JAX operations, `jax.grad` can diff

Randomizing mass and inertia through the reset pipeline. An optional mask limits randomization to selected worlds.

<!-- notest: imported script, covered by tests/integration/test_examples.py -->
```{ .python notest }
--8<-- "examples/plugins/randomize.py"
```
Expand All @@ -66,6 +70,7 @@ python examples/plugins/randomize.py

Inserting a random external force and torque into the step pipeline. The disturbance fires on every dynamics tick, so the drone fights wind-like perturbations.

<!-- notest: imported script, covered by tests/integration/test_examples.py -->
```{ .python notest }
--8<-- "examples/plugins/disturbance.py"
```
Expand All @@ -80,6 +85,7 @@ Offscreen rendering returns RGB and depth images on every frame. The FPV camera
<img src="../img/examples/cameras.gif" alt="RGB and depth camera outputs from a Crazyflow drone simulation">
</figure>

<!-- notest: imported script, covered by tests/integration/test_examples.py -->
```{ .python notest }
--8<-- "examples/rendering/cameras.py"
```
Expand All @@ -98,6 +104,7 @@ python examples/rendering/cameras.py
<img src="../img/examples/led_decks.png" alt="Crazyflow drones with runtime-controlled LED deck materials">
</figure>

<!-- notest: imported script, covered by tests/integration/test_examples.py -->
```{ .python notest }
--8<-- "examples/rendering/led_deck.py"
```
Expand All @@ -121,6 +128,7 @@ The default collision geometry is a sphere around the drone frame. `use_box_coll
</figure>
</div>

<!-- notest: imported script, covered by tests/integration/test_examples.py -->
```{ .python notest }
--8<-- "examples/contacts/contacts.py"
```
Expand All @@ -131,6 +139,7 @@ The default collision geometry is a sphere around the drone frame. `use_box_coll

`render_depth` fires rays from a camera and returns per-pixel distances. This is faster than full RGB rendering and useful for obstacle sensing or depth-based controllers.

<!-- notest: imported script, covered by tests/integration/test_examples.py -->
```{ .python notest }
--8<-- "examples/rendering/raycasting.py"
```
Expand All @@ -145,6 +154,7 @@ python examples/rendering/raycasting.py

Evaluating a random policy in the figure-8 environment. The env wraps `Sim` behind the standard Gymnasium `VectorEnv` interface.

<!-- notest: imported script, covered by tests/integration/test_examples.py -->
```{ .python notest }
--8<-- "examples/environments/figure8.py"
```
1 change: 1 addition & 0 deletions docs/user-guide/control/parametrize.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ float(params["mass"]) # 0.029

By default parameters are stored as NumPy arrays. Pass `xp` to convert them upfront, which avoids per-call conversion overhead in frameworks like PyTorch or JAX:

<!-- notest: requires torch (not a dependency) -->
```{ .python notest }
import torch
from crazyflow.control import parametrize
Expand Down
2 changes: 2 additions & 0 deletions docs/user-guide/dynamics/parametrize.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ If your drone is not listed, you can identify the parameters from flight data us

By default, `parametrize` stores parameters as NumPy arrays. For frameworks that would otherwise need to convert those arrays on every call — such as PyTorch, where NumPy arrays must become tensors — passing `xp` converts the parameters upfront. The backend of the outputs is always inferred from whatever arrays you pass in at call time.

<!-- notest: requires torch (not a dependency) -->
```{ .python notest }
import torch
import jax.numpy as jnp
Expand All @@ -49,6 +50,7 @@ dynamics_jax = parametrize(dynamics, drone="cf2x_L250", xp=jnp)

You can also specify a compute device — for example, to move JAX parameters to GPU at construction time:

<!-- notest: requires a GPU -->
```{ .python notest }
import jax
import jax.numpy as jnp
Expand Down
3 changes: 3 additions & 0 deletions docs/user-guide/dynamics/system-identification.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ After `preprocessing` + [`derivatives_svf`][crazyflow.dynamics.utils.data_utils.

## Full pipeline

<!-- notest: requires flight-log data -->
```{ .python notest }
from crazyflow.dynamics.utils.data_utils import preprocessing, derivatives_svf
from crazyflow.dynamics.utils.identification import sys_id_translation, sys_id_rotation
Expand Down Expand Up @@ -61,6 +62,7 @@ See [`sys_id_translation`][crazyflow.dynamics.utils.identification.sys_id_transl

To check that the identified parameters generalise to unseen flight regimes, collect a second dataset of different trajectories and pass it as `data_validation`. RMSE and R² are then reported on both the training data and the validation data.

<!-- notest: requires flight-log data -->
```{ .python notest }
# Preprocess the validation dataset independently — it must come from
# different trajectories, not a split of the same recording.
Expand Down Expand Up @@ -97,6 +99,7 @@ cmd_rpy_coef = [196.18, 196.18, 390.27] # from rot_params["cmd_rpy_coef"

Once the entry is in the TOML file, load the dynamics as usual:

<!-- notest: example drone not in params.toml -->
```{ .python notest }
from crazyflow.dynamics import parametrize
from crazyflow.dynamics.so_rpy_rotor_drag import dynamics
Expand Down
24 changes: 0 additions & 24 deletions docs/user-guide/functional-api.md
Comment thread
ratheron marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,6 @@ The object-oriented API is convenient for scripting, but it relies on Python-lev

The functional API addresses this by expressing the same operations as pure functions that take `SimData` and return updated `SimData`. There is no hidden state, so JAX can trace, compile, and differentiate through arbitrary compositions of these functions.

## What does not work inside JAX transformations

The object-oriented `Sim` methods mutate `sim.data` in place through Python calls. JAX cannot trace through Python-level state mutations, so these methods cannot be used inside `jax.jit`, `jax.grad`, or `jax.lax.scan`:

```{ .python notest }
import jax
import jax.numpy as jnp
from crazyflow.sim import Sim
from crazyflow.control import Control
sim = Sim(control=Control.attitude)
sim.reset()
@jax.jit
def broken(cmd):
sim.attitude_control(cmd) # mutates sim.data — JAX traces the ops but leaks the tracer
sim.step(1)
return sim.data.states.pos # sim.data now holds a leaked tracer; accessing it outside JIT raises UnexpectedTracerError
```

## What does work

The purely functional counterpart passes `SimData` explicitly and returns updated `SimData`. Every operation is a plain JAX function with no Python-level mutation, so the full simulation pipeline is traceable by any JAX transformation:

```python
import jax
import jax.numpy as jnp
Expand Down
12 changes: 11 additions & 1 deletion docs/user-guide/mujoco.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ sim.reset()

Loading from a file works identically:

<!-- notest: reads external asset assets/gate.xml -->
```{ .python notest }
import mujoco
gate_spec = mujoco.MjSpec.from_file("assets/gate.xml")
Expand All @@ -100,6 +101,7 @@ After `sim.step()` or `sim.reset()`, `mjx_synced` is set to `False`. The `sim.re

These run only once per render or contact call, regardless of how many dynamics steps were taken since the last sync.

<!-- notest: requires rendering -->
```{ .python notest }
for i in range(10):
sim.step(5) # JAX dynamics only, mjx_synced = False
Expand All @@ -113,6 +115,7 @@ for i in range(10):

This means the order of calls matters. Grouping all rendering and contact queries together after a step lets them share a single sync:

<!-- notest: requires rendering -->
```{ .python notest }
sim.step(5)
contacts = sim.contacts() # sync runs here
Expand All @@ -121,6 +124,7 @@ sim.render(mode="rgb_array") # flag already set, no second sync

Interleaving a step between them forces two syncs:

<!-- notest: requires rendering -->
```{ .python notest }
contacts = sim.contacts() # sync runs here
sim.step(5) # flag cleared
Expand All @@ -135,19 +139,25 @@ The solution is to **close over** `mjx_data` rather than pass it as an argument.

The drone racing environment in [lsy_drone_racing](https://github.com/learnsyslab/lsy_drone_racing) uses this pattern to build a contact check function:

```{ .python notest }
```python
from jax import Array

from crazyflow.sim import Sim
from crazyflow.sim.sim import sync_sim2mjx
from crazyflow.sim.data import SimData

sim = Sim(n_worlds=1, n_drones=1)
sim.reset()

_mjx_data = sim.mjx_data # captured in closure

def check_contacts(sim_data: SimData, obstacle_mocap_pos: Array) -> Array:
# Update obstacle positions and sync inside JIT
mjx_data = _mjx_data.replace(mocap_pos=obstacle_mocap_pos)
_, mjx_data = sync_sim2mjx(sim_data, mjx_data, sim.mjx_model)
return mjx_data._impl.contact.dist < 0

in_contact = check_contacts(sim.data, sim.mjx_data.mocap_pos)
```

`_mjx_data` is fused into the closure and compiled as a constant. Only `sim_data` and the obstacle positions cross the JIT boundary at runtime — a much smaller pytree than passing the full `mjx_data`.
Expand Down
1 change: 1 addition & 0 deletions docs/user-guide/oo-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ pos_w0_d1 = sim.data.states.pos[0, 1] # (3,)

`sim.render()` opens an interactive MuJoCo viewer or returns an image array for offscreen rendering.

<!-- notest: requires rendering -->
```{ .python notest }
sim.render() # interactive window, world 0
sim.render(mode="rgb_array") # returns (H, W, 3) uint8
Expand Down
5 changes: 5 additions & 0 deletions docs/user-guide/visualization.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Crazyflow supports onscreen interactive rendering and offscreen RGB/depth captur
| `"depth_array"` | `(H, W) float32` | Offscreen depth frame in metres |
| `"rgbd_tuple"` | `(rgb, depth)` | Both channels as a tuple |

<!-- notest: requires rendering -->
```{ .python notest }
sim.render() # interactive window
rgb = sim.render(mode="rgb_array") # numpy array (H, W, 3)
Expand All @@ -33,6 +34,7 @@ sim.close() # close the viewer

Pass a camera name or integer ID to select which camera to render from. The default (`camera=-1`) uses the free camera. Each drone ships with a first-person view camera named `fpv_cam:<drone_index>`:

<!-- notest: requires rendering -->
```{ .python notest }
sim.render(camera="fpv_cam:0") # first-person view from drone 0
sim.render(camera=0) # camera by integer ID
Expand All @@ -42,6 +44,7 @@ sim.render(camera=0) # camera by integer ID

For obstacle sensing or perception-based controllers, `render_depth` fires a ray from each camera pixel and returns per-pixel distances — faster than full RGB rendering because it skips lighting and colour computation:

<!-- notest: requires rendering -->
```{ .python notest }
import jax.numpy as jnp
from crazyflow.sim.sensors import build_render_depth_fn, render_depth
Expand All @@ -64,6 +67,7 @@ dist = render_fn(sim)

`change_material` updates the RGBA colour and emission intensity of any named material on any subset of drones without rebuilding the model:

<!-- notest: requires rendering -->
```{ .python notest }
import numpy as np
from crazyflow.sim.visualize import change_material
Expand All @@ -77,6 +81,7 @@ sim.render()

`sim.render()` always renders a single world at a time. Pass `world=<index>` to choose which one:

<!-- notest: requires rendering -->
```{ .python notest }
sim.render(world=0) # default
sim.render(world=3) # render world 3
Expand Down
Loading