Conversation
…nication handling and model architecture - Updated Transition data structure to include talk actions for IC3Net. - Modified network initialization to accommodate batch dimensions for communication actions. - Enhanced the training loop to support shaped rewards and improved loss calculation with re-evaluation. - Adjusted communication action handling in visualization scripts to ensure compatibility with new dimensions. - Refined CommNet models to support multi-layer encoders and optimized LSTM cell usage. - Improved logging to monitor mean reward per step and adjusted final output messages for clarity. - Ensured consistent handling of communication actions across various components of the IC3Net framework.
There was a problem hiding this comment.
Pull request overview
Adds an IC3Net-family baseline suite to JaxMARL (IC / IRIC / CommNet / IC3Net) with JAX/Flax models, a REINFORCE-style trainer + inference tooling, and registers two IC3Net reference grid environments (Predator-Prey and Traffic Junction).
Changes:
- Added
PredatorPreyGridandTrafficJunctionGridenvironments and registered them injaxmarl.make. - Introduced IC3Net-family models (FF + LSTM), REINFORCE-with-value-baseline training loop, inference, visualization scripts, and Hydra configs.
- Updated an IPPO baseline to use the correct per-agent
action_space(...)API; expanded.gitignorefor checkpoints/msgpacks.
Reviewed changes
Copilot reviewed 61 out of 62 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| jaxmarl/registration.py | Registers predator_prey_grid and traffic_junction_grid in make() and registered_envs. |
| jaxmarl/environments/traffic_junction_grid.py | Adds a JAX grid Traffic Junction environment (IC3Net reference task). |
| jaxmarl/environments/predator_prey_grid.py | Adds a JAX grid Predator-Prey environment (IC3Net reference task). |
| jaxmarl/environments/init.py | Exposes new grid env classes at package import level. |
| baselines/IPPO/ippo_ff_overcooked.py | Fixes action-space access to use env.action_space("agent_0"). |
| baselines/IC3Net/visualize_overcooked.py | Adds an Overcooked visualization utility for IC3Net checkpoints. |
| baselines/IC3Net/train_overcooked.py | Adds a quick Overcooked training demo script for IC3Net. |
| baselines/IC3Net/monitor.py | Adds a Rich-based TrainingMonitor with a non-TTY fallback interface. |
| baselines/IC3Net/models.py | Implements IC/CommNet/IC3Net model cores (FF + LSTM) in Flax. |
| baselines/IC3Net/ic3net_visualize.py | Adds a GUI-based inference + visualization runner (matplotlib). |
| baselines/IC3Net/ic3net_train.py | Adds Hydra-driven REINFORCE training loop w/ checkpoints + monitor + wandb hooks. |
| baselines/IC3Net/ic3net_infer.py | Adds Hydra-driven inference/evaluation + optional GIF saving. |
| baselines/IC3Net/config/test_lstm.yaml | Adds a quick LSTM smoke-test training config. |
| baselines/IC3Net/config/test_ff.yaml | Adds a quick feedforward smoke-test training config. |
| baselines/IC3Net/config/iric_tj_medium.yaml | Adds IRIC config for TrafficJunction medium. |
| baselines/IC3Net/config/iric_tj_hard.yaml | Adds IRIC config for TrafficJunction hard. |
| baselines/IC3Net/config/iric_tj_easy.yaml | Adds IRIC config for TrafficJunction easy. |
| baselines/IC3Net/config/iric_pp_medium.yaml | Adds IRIC config for PredatorPrey medium. |
| baselines/IC3Net/config/iric_pp_hard.yaml | Adds IRIC config for PredatorPrey hard. |
| baselines/IC3Net/config/iric_pp_easy.yaml | Adds IRIC config for PredatorPrey easy. |
| baselines/IC3Net/config/iric_overcooked_medium.yaml | Adds IRIC config for Overcooked medium. |
| baselines/IC3Net/config/iric_overcooked.yaml | Adds IRIC config for Overcooked easy. |
| baselines/IC3Net/config/iric_mpe.yaml | Adds IRIC config for MPE Simple Spread. |
| baselines/IC3Net/config/ic_tj_medium.yaml | Adds IC config for TrafficJunction medium. |
| baselines/IC3Net/config/ic_tj_hard.yaml | Adds IC config for TrafficJunction hard. |
| baselines/IC3Net/config/ic_tj_easy.yaml | Adds IC config for TrafficJunction easy. |
| baselines/IC3Net/config/ic_pp_medium.yaml | Adds IC config for PredatorPrey medium. |
| baselines/IC3Net/config/ic_pp_hard.yaml | Adds IC config for PredatorPrey hard. |
| baselines/IC3Net/config/ic_pp_easy.yaml | Adds IC config for PredatorPrey easy. |
| baselines/IC3Net/config/ic_overcooked_medium.yaml | Adds IC config for Overcooked medium. |
| baselines/IC3Net/config/ic_overcooked.yaml | Adds IC config for Overcooked easy. |
| baselines/IC3Net/config/ic_mpe.yaml | Adds IC config for MPE Simple Spread. |
| baselines/IC3Net/config/ic3net_tj_medium.yaml | Adds IC3Net config for TrafficJunction medium. |
| baselines/IC3Net/config/ic3net_tj_hard.yaml | Adds IC3Net config for TrafficJunction hard. |
| baselines/IC3Net/config/ic3net_tj_easy_infer.yaml | Adds IC3Net inference config for TrafficJunction easy. |
| baselines/IC3Net/config/ic3net_tj_easy.yaml | Adds IC3Net training config for TrafficJunction easy. |
| baselines/IC3Net/config/ic3net_pp_medium_infer.yaml | Adds IC3Net inference config for PredatorPrey medium. |
| baselines/IC3Net/config/ic3net_pp_medium.yaml | Adds IC3Net training config for PredatorPrey medium. |
| baselines/IC3Net/config/ic3net_pp_hard_infer.yaml | Adds IC3Net inference config for PredatorPrey hard. |
| baselines/IC3Net/config/ic3net_pp_hard.yaml | Adds IC3Net training config for PredatorPrey hard. |
| baselines/IC3Net/config/ic3net_pp_easy_infer.yaml | Adds IC3Net inference config for PredatorPrey easy. |
| baselines/IC3Net/config/ic3net_pp_easy.yaml | Adds IC3Net training config for PredatorPrey easy. |
| baselines/IC3Net/config/ic3net_overcooked_v3_medium.yaml | Adds IC3Net config for Overcooked v3 medium. |
| baselines/IC3Net/config/ic3net_overcooked_medium_test.yaml | Adds a quick IC3Net Overcooked medium test config. |
| baselines/IC3Net/config/ic3net_overcooked_medium_infer.yaml | Adds IC3Net inference config for Overcooked medium. |
| baselines/IC3Net/config/ic3net_overcooked_medium.yaml | Adds IC3Net training config for Overcooked medium. |
| baselines/IC3Net/config/ic3net_overcooked_infer.yaml | Adds IC3Net inference config for Overcooked easy. |
| baselines/IC3Net/config/ic3net_overcooked.yaml | Adds IC3Net training config for Overcooked easy. |
| baselines/IC3Net/config/ic3net_mpe_infer.yaml | Adds IC3Net inference config for MPE Simple Spread. |
| baselines/IC3Net/config/ic3net_mpe.yaml | Adds IC3Net training config for MPE Simple Spread. |
| baselines/IC3Net/config/commnet_tj_medium.yaml | Adds CommNet config for TrafficJunction medium. |
| baselines/IC3Net/config/commnet_tj_hard.yaml | Adds CommNet config for TrafficJunction hard. |
| baselines/IC3Net/config/commnet_tj_easy.yaml | Adds CommNet config for TrafficJunction easy. |
| baselines/IC3Net/config/commnet_pp_medium.yaml | Adds CommNet config for PredatorPrey medium. |
| baselines/IC3Net/config/commnet_pp_hard.yaml | Adds CommNet config for PredatorPrey hard. |
| baselines/IC3Net/config/commnet_pp_easy.yaml | Adds CommNet config for PredatorPrey easy. |
| baselines/IC3Net/config/commnet_overcooked_medium.yaml | Adds CommNet config for Overcooked medium. |
| baselines/IC3Net/config/commnet_overcooked.yaml | Adds CommNet config for Overcooked easy. |
| baselines/IC3Net/config/commnet_mpe.yaml | Adds CommNet config for MPE Simple Spread. |
| baselines/IC3Net/init.py | Exposes baseline modules/classes for import. |
| baselines/IC3Net/README.md | Adds baseline family documentation and quickstart commands. |
| .gitignore | Ignores msgpack checkpoints and common output dirs. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| mode: str = "mixed", | ||
| ): | ||
| super().__init__(num_agents) | ||
|
|
||
| self.dim = dim | ||
| self.vision = vision | ||
| self.max_steps = max_steps | ||
| self.mode = mode | ||
| self.npredator = num_agents | ||
| self.nprey = 1 | ||
|
|
||
| # Grid constants | ||
| self.base = dim * dim | ||
| self.outside_class = self.base + 1 | ||
| self.prey_class = self.base + 2 | ||
| self.predator_class = self.base + 3 | ||
| self.vocab_size = self.base + 4 | ||
|
|
||
| # Vision window | ||
| self.vis_size = 2 * vision + 1 | ||
|
|
||
| # Flat observation dim: vocab_size * vis_size * vis_size | ||
| self.obs_dim = self.vocab_size * self.vis_size * self.vis_size | ||
|
|
||
| # Agents | ||
| self.agents = [f"agent_{i}" for i in range(num_agents)] | ||
|
|
||
| # Reward constants (matching reference) | ||
| self.TIMESTEP_PENALTY = -0.05 | ||
| self.PREY_REWARD = 0.0 # reward for agent ON prey in mixed mode | ||
| self.POS_PREY_REWARD = 0.05 # reward for agent ON prey in cooperative mode | ||
|
|
There was a problem hiding this comment.
The mode argument and constants like POS_PREY_REWARD are defined, but step_env currently always uses the mixed-mode reward logic and never branches on self.mode. This makes mode effectively ignored and can confuse config-driven experiments (e.g., cooperative vs mixed). Either implement the intended mode-specific reward/done logic or remove the unused parameter/constants.
| # Loss function | ||
| def loss_fn(params): | ||
| policy_loss = -jnp.mean(log_probs * advantages) | ||
| value_loss = jnp.mean((values - returns) ** 2) | ||
| entropy_loss = -jnp.mean(entropies) | ||
|
|
||
| loss = ( | ||
| policy_loss | ||
| + config["VALUE_COEFF"] * value_loss | ||
| + config["ENTROPY_COEFF"] * entropy_loss | ||
| ) | ||
|
|
||
| return loss, { | ||
| "loss": loss, | ||
| "policy_loss": policy_loss, | ||
| "value_loss": value_loss, | ||
| "entropy": -entropy_loss, | ||
| } | ||
|
|
||
| # Compute gradients and update | ||
| (loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)( | ||
| train_state.params | ||
| ) | ||
| train_state = train_state.apply_gradients(grads=grads) | ||
|
|
There was a problem hiding this comment.
loss_fn(params) does not re-evaluate the policy/value outputs with the provided params; it instead uses log_probs, values, and entropies computed earlier during rollout collection. This typically breaks gradient flow (or at least makes it easy to accidentally differentiate w.r.t. stale values), so the update may produce zero/incorrect gradients. To make this a correct training example, recompute logits/values (and talk logits) inside loss_fn using params (similar to the re-evaluation pattern used in ic3net_train.py).
| def _build_network(config, num_agents, action_dim): | ||
| """Build network based on baseline type and recurrence setting.""" | ||
| baseline = config.get("BASELINE", "ic3net") | ||
| recurrent = config.get("RECURRENT", True) | ||
| hidden_dim = config.get("HIDDEN_DIM", 64) | ||
|
|
||
| if baseline in ("ic", "iric"): | ||
| if recurrent: | ||
| network = IndependentLSTM(action_dim=action_dim, hidden_dim=hidden_dim) | ||
| else: | ||
| network = IndependentMLP(action_dim=action_dim, hidden_dim=hidden_dim) | ||
| has_talk = False | ||
| else: | ||
| hard_attn = (baseline == "ic3net") | ||
| kw = dict( | ||
| num_agents=num_agents, | ||
| action_dim=action_dim, | ||
| hidden_dim=hidden_dim, | ||
| comm_passes=config.get("COMM_PASSES", 1), | ||
| comm_mode=config.get("COMM_MODE", "avg"), | ||
| hard_attn=hard_attn, | ||
| share_weights=config.get("SHARE_WEIGHTS", False), | ||
| encoder_layers=config.get("ENCODER_LAYERS", 1), | ||
| ) | ||
| network = CommNetLSTM(**kw) if recurrent else CommNetDiscrete(**kw) | ||
| has_talk = hard_attn |
There was a problem hiding this comment.
The code treats BASELINE: iric identically to BASELINE: ic (same network, same return/advantage computation). This conflicts with the PR description that IRIC uses a different return formulation (e.g., mean_ratio=0) and means selecting iric currently has no behavioral effect. Please either implement the IRIC-specific return logic in the update step (and document it in config) or adjust the PR description/configs to reflect that IRIC is currently equivalent to IC.
| # 10. Predator-Prey Grid (IC3Net reference) | ||
| elif env_id == "predator_prey_grid": | ||
| env = PredatorPreyGrid(**env_kwargs) | ||
|
|
||
| # 11. Traffic Junction Grid (IC3Net reference) | ||
| elif env_id == "traffic_junction_grid": | ||
| env = TrafficJunctionGrid(**env_kwargs) | ||
|
|
There was a problem hiding this comment.
New environment IDs are added to registration, but the environment API test suite (tests/test_envs_implement_base_api.py) maintains an explicit envs_to_test list and currently won’t exercise these new envs. Consider adding predator_prey_grid and traffic_junction_grid to that list (or switching the test to iterate over registered_envs) to ensure reset/step and space contains() checks run in CI for the new environments.
| if self.hard_attn: | ||
| if comm_action is None: | ||
| ca = jnp.zeros((batch_size, n)) | ||
| else: | ||
| ca = comm_action.astype(jnp.float32) | ||
| if ca.ndim == 1: | ||
| ca = jnp.broadcast_to(ca, (batch_size, n)) | ||
| ca_mask = ca.reshape(batch_size, 1, n) | ||
| ca_mask = jnp.broadcast_to(ca_mask, (batch_size, n, n)) | ||
| ca_mask = jnp.expand_dims(ca_mask, -1) | ||
| agent_mask = agent_mask * ca_mask | ||
|
|
There was a problem hiding this comment.
In hard-attention mode, the talk/silent gating mask is currently built from comm_action.reshape(batch_size, 1, n) and broadcast to (B, n, n), which applies the talk mask on the receiver dimension. For IC3Net semantics, talk/silent typically controls whether an agent sends messages, so the mask should apply on the sender axis (e.g., shape (B, n, 1) before broadcasting) or otherwise ensure the sender dimension is gated. As written, silent agents may still send messages, and talking agents may be prevented from receiving instead of sending.
| if self.hard_attn: | ||
| if comm_action is None: | ||
| ca = jnp.zeros((batch_size, n)) | ||
| else: | ||
| ca = comm_action.astype(jnp.float32) | ||
| if ca.ndim == 1: | ||
| ca = jnp.broadcast_to(ca, (batch_size, n)) | ||
| ca_mask = ca.reshape(batch_size, 1, n) | ||
| ca_mask = jnp.broadcast_to(ca_mask, (batch_size, n, n)) | ||
| ca_mask = jnp.expand_dims(ca_mask, -1) | ||
| agent_mask = agent_mask * ca_mask | ||
|
|
There was a problem hiding this comment.
Same issue as in CommNetDiscrete: the hard-attention comm_action mask is broadcast from shape (B, 1, N) to (B, N, N), which gates the receiver axis rather than the sender axis. If comm_action is meant to represent each agent's talk/silent decision, this should mask outgoing messages (sender dimension) to match IC3Net semantics; otherwise silent agents can still contribute to the aggregated message.
| self.observation_spaces = { | ||
| a: Box(0.0, 1.0, (self.vocab_size, self.vis_size, self.vis_size)) | ||
| for a in self.agents | ||
| } | ||
| self.action_spaces = {a: Discrete(5) for a in self.agents} |
There was a problem hiding this comment.
observation_spaces declares a Box with high=1.0 for all entries, but the observation construction uses .add(1.0) when marking predators. If multiple predators occupy the same cell (which is possible given movement allows overlaps), the predator channel can exceed 1.0 and violate the declared space (and any Space.contains checks). Consider making these marks binary (set/maximum/clip to 1.0) or increasing the Box high bound to accommodate counts.
…ompact observations and shaped rewards
… MAPPO and QLearning scripts
PR: Add IC3Net-family baselines with REINFORCE trainer and TrainingMonitor
Summary
Adds a complete IC3Net-family baseline suite (IC, IRIC, CommNet, IC3Net) to JaxMARL, implemented in JAX/Flax. Includes a REINFORCE-with-value-baseline trainer that runs a Python-level training loop with live terminal progress reporting via a
TrainingMonitor(rich UI), periodic checkpoint saving, and Hydra config-driven experiment management. Supports Overcooked (v1, v3) and MPE environments out of the box.Architecture
Model variants (all share the same training loop):
mean_ratio=0)Each has both feedforward and LSTM-recurrent variants, selected via
RECURRENT: true/false.Training loop (
ic3net_train.py):jax.lax.scan(JIT-compiled, fast)TrainingMonitorInterfacewraps the richTrainingMonitor(falls back to plain-text when not in a TTY).msgpackat configurable intervalsKey design decisions
(B, H, W, C)are flattened to(B, H*W*C)before feeding to the network, matching the reference approach.TrainingMonitoras a standalone module: Reusable across baseline families, not specific to IC3Net.Testing
Usage