Skip to content

Ic3net and family#7

Open
daniel-debrun wants to merge 14 commits intomainfrom
ic3net_and_family
Open

Ic3net and family#7
daniel-debrun wants to merge 14 commits intomainfrom
ic3net_and_family

Conversation

@daniel-debrun
Copy link
Copy Markdown
Member

PR: Add IC3Net-family baselines with REINFORCE trainer and TrainingMonitor

Summary

Adds a complete IC3Net-family baseline suite (IC, IRIC, CommNet, IC3Net) to JaxMARL, implemented in JAX/Flax. Includes a REINFORCE-with-value-baseline trainer that runs a Python-level training loop with live terminal progress reporting via a TrainingMonitor (rich UI), periodic checkpoint saving, and Hydra config-driven experiment management. Supports Overcooked (v1, v3) and MPE environments out of the box.

Architecture

Model variants (all share the same training loop):

  • IC — Independent per-agent MLP, no communication
  • IRIC — Same architecture, different return formulation (mean_ratio=0)
  • CommNet — Continuous communication via message passing (soft attention)
  • IC3Net — CommNet + learned binary talk/silent gating (hard attention)

Each has both feedforward and LSTM-recurrent variants, selected via RECURRENT: true/false.

Training loop (ic3net_train.py):

  1. Inner rollout collection uses jax.lax.scan (JIT-compiled, fast)
  2. Outer loop is Python-level — enables live metric reporting and checkpoint I/O
  3. TrainingMonitorInterface wraps the rich TrainingMonitor (falls back to plain-text when not in a TTY)
  4. Checkpoints saved as Flax .msgpack at configurable intervals

Key design decisions

  • REINFORCE, not PPO: Following the original IC3Net paper and reference implementation. The talk/silent gating semantics require REINFORCE-style policy gradients.
  • Python outer loop: Trades marginal speed for live observability (progress bar, metrics, checkpoints). The JIT-compiled inner loop handles the heavy lifting.
  • Spatial obs flattening: Overcooked observations (B, H, W, C) are flattened to (B, H*W*C) before feeding to the network, matching the reference approach.
  • TrainingMonitor as a standalone module: Reusable across baseline families, not specific to IC3Net.

Testing

  • All 4 model variants (IC, IRIC, CommNet, IC3Net) × feedforward/LSTM initialise and train
  • Checkpoint save/load round-trips verified
  • Envs: Overcooked_v1, Overcooked_v2, Overcooked_v3

Usage

# Train IC3Net on MPE
python baselines/IC3Net/ic3net_train.py

# Train on Overcooked medium
python baselines/IC3Net/ic3net_train.py --config-name=ic3net_overcooked_medium

# Inference with checkpoint
python baselines/IC3Net/ic3net_infer.py --config-name=ic3net_mpe_infer MODEL_PATH=checkpoints/.../model.msgpack

daniel-debrun added 6 commits February 15, 2026 13:38
…nication handling and model architecture

- Updated Transition data structure to include talk actions for IC3Net.
- Modified network initialization to accommodate batch dimensions for communication actions.
- Enhanced the training loop to support shaped rewards and improved loss calculation with re-evaluation.
- Adjusted communication action handling in visualization scripts to ensure compatibility with new dimensions.
- Refined CommNet models to support multi-layer encoders and optimized LSTM cell usage.
- Improved logging to monitor mean reward per step and adjusted final output messages for clarity.
- Ensured consistent handling of communication actions across various components of the IC3Net framework.
Copilot AI review requested due to automatic review settings March 1, 2026 17:50
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds an IC3Net-family baseline suite to JaxMARL (IC / IRIC / CommNet / IC3Net) with JAX/Flax models, a REINFORCE-style trainer + inference tooling, and registers two IC3Net reference grid environments (Predator-Prey and Traffic Junction).

Changes:

  • Added PredatorPreyGrid and TrafficJunctionGrid environments and registered them in jaxmarl.make.
  • Introduced IC3Net-family models (FF + LSTM), REINFORCE-with-value-baseline training loop, inference, visualization scripts, and Hydra configs.
  • Updated an IPPO baseline to use the correct per-agent action_space(...) API; expanded .gitignore for checkpoints/msgpacks.

Reviewed changes

Copilot reviewed 61 out of 62 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
jaxmarl/registration.py Registers predator_prey_grid and traffic_junction_grid in make() and registered_envs.
jaxmarl/environments/traffic_junction_grid.py Adds a JAX grid Traffic Junction environment (IC3Net reference task).
jaxmarl/environments/predator_prey_grid.py Adds a JAX grid Predator-Prey environment (IC3Net reference task).
jaxmarl/environments/init.py Exposes new grid env classes at package import level.
baselines/IPPO/ippo_ff_overcooked.py Fixes action-space access to use env.action_space("agent_0").
baselines/IC3Net/visualize_overcooked.py Adds an Overcooked visualization utility for IC3Net checkpoints.
baselines/IC3Net/train_overcooked.py Adds a quick Overcooked training demo script for IC3Net.
baselines/IC3Net/monitor.py Adds a Rich-based TrainingMonitor with a non-TTY fallback interface.
baselines/IC3Net/models.py Implements IC/CommNet/IC3Net model cores (FF + LSTM) in Flax.
baselines/IC3Net/ic3net_visualize.py Adds a GUI-based inference + visualization runner (matplotlib).
baselines/IC3Net/ic3net_train.py Adds Hydra-driven REINFORCE training loop w/ checkpoints + monitor + wandb hooks.
baselines/IC3Net/ic3net_infer.py Adds Hydra-driven inference/evaluation + optional GIF saving.
baselines/IC3Net/config/test_lstm.yaml Adds a quick LSTM smoke-test training config.
baselines/IC3Net/config/test_ff.yaml Adds a quick feedforward smoke-test training config.
baselines/IC3Net/config/iric_tj_medium.yaml Adds IRIC config for TrafficJunction medium.
baselines/IC3Net/config/iric_tj_hard.yaml Adds IRIC config for TrafficJunction hard.
baselines/IC3Net/config/iric_tj_easy.yaml Adds IRIC config for TrafficJunction easy.
baselines/IC3Net/config/iric_pp_medium.yaml Adds IRIC config for PredatorPrey medium.
baselines/IC3Net/config/iric_pp_hard.yaml Adds IRIC config for PredatorPrey hard.
baselines/IC3Net/config/iric_pp_easy.yaml Adds IRIC config for PredatorPrey easy.
baselines/IC3Net/config/iric_overcooked_medium.yaml Adds IRIC config for Overcooked medium.
baselines/IC3Net/config/iric_overcooked.yaml Adds IRIC config for Overcooked easy.
baselines/IC3Net/config/iric_mpe.yaml Adds IRIC config for MPE Simple Spread.
baselines/IC3Net/config/ic_tj_medium.yaml Adds IC config for TrafficJunction medium.
baselines/IC3Net/config/ic_tj_hard.yaml Adds IC config for TrafficJunction hard.
baselines/IC3Net/config/ic_tj_easy.yaml Adds IC config for TrafficJunction easy.
baselines/IC3Net/config/ic_pp_medium.yaml Adds IC config for PredatorPrey medium.
baselines/IC3Net/config/ic_pp_hard.yaml Adds IC config for PredatorPrey hard.
baselines/IC3Net/config/ic_pp_easy.yaml Adds IC config for PredatorPrey easy.
baselines/IC3Net/config/ic_overcooked_medium.yaml Adds IC config for Overcooked medium.
baselines/IC3Net/config/ic_overcooked.yaml Adds IC config for Overcooked easy.
baselines/IC3Net/config/ic_mpe.yaml Adds IC config for MPE Simple Spread.
baselines/IC3Net/config/ic3net_tj_medium.yaml Adds IC3Net config for TrafficJunction medium.
baselines/IC3Net/config/ic3net_tj_hard.yaml Adds IC3Net config for TrafficJunction hard.
baselines/IC3Net/config/ic3net_tj_easy_infer.yaml Adds IC3Net inference config for TrafficJunction easy.
baselines/IC3Net/config/ic3net_tj_easy.yaml Adds IC3Net training config for TrafficJunction easy.
baselines/IC3Net/config/ic3net_pp_medium_infer.yaml Adds IC3Net inference config for PredatorPrey medium.
baselines/IC3Net/config/ic3net_pp_medium.yaml Adds IC3Net training config for PredatorPrey medium.
baselines/IC3Net/config/ic3net_pp_hard_infer.yaml Adds IC3Net inference config for PredatorPrey hard.
baselines/IC3Net/config/ic3net_pp_hard.yaml Adds IC3Net training config for PredatorPrey hard.
baselines/IC3Net/config/ic3net_pp_easy_infer.yaml Adds IC3Net inference config for PredatorPrey easy.
baselines/IC3Net/config/ic3net_pp_easy.yaml Adds IC3Net training config for PredatorPrey easy.
baselines/IC3Net/config/ic3net_overcooked_v3_medium.yaml Adds IC3Net config for Overcooked v3 medium.
baselines/IC3Net/config/ic3net_overcooked_medium_test.yaml Adds a quick IC3Net Overcooked medium test config.
baselines/IC3Net/config/ic3net_overcooked_medium_infer.yaml Adds IC3Net inference config for Overcooked medium.
baselines/IC3Net/config/ic3net_overcooked_medium.yaml Adds IC3Net training config for Overcooked medium.
baselines/IC3Net/config/ic3net_overcooked_infer.yaml Adds IC3Net inference config for Overcooked easy.
baselines/IC3Net/config/ic3net_overcooked.yaml Adds IC3Net training config for Overcooked easy.
baselines/IC3Net/config/ic3net_mpe_infer.yaml Adds IC3Net inference config for MPE Simple Spread.
baselines/IC3Net/config/ic3net_mpe.yaml Adds IC3Net training config for MPE Simple Spread.
baselines/IC3Net/config/commnet_tj_medium.yaml Adds CommNet config for TrafficJunction medium.
baselines/IC3Net/config/commnet_tj_hard.yaml Adds CommNet config for TrafficJunction hard.
baselines/IC3Net/config/commnet_tj_easy.yaml Adds CommNet config for TrafficJunction easy.
baselines/IC3Net/config/commnet_pp_medium.yaml Adds CommNet config for PredatorPrey medium.
baselines/IC3Net/config/commnet_pp_hard.yaml Adds CommNet config for PredatorPrey hard.
baselines/IC3Net/config/commnet_pp_easy.yaml Adds CommNet config for PredatorPrey easy.
baselines/IC3Net/config/commnet_overcooked_medium.yaml Adds CommNet config for Overcooked medium.
baselines/IC3Net/config/commnet_overcooked.yaml Adds CommNet config for Overcooked easy.
baselines/IC3Net/config/commnet_mpe.yaml Adds CommNet config for MPE Simple Spread.
baselines/IC3Net/init.py Exposes baseline modules/classes for import.
baselines/IC3Net/README.md Adds baseline family documentation and quickstart commands.
.gitignore Ignores msgpack checkpoints and common output dirs.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +46 to +77
mode: str = "mixed",
):
super().__init__(num_agents)

self.dim = dim
self.vision = vision
self.max_steps = max_steps
self.mode = mode
self.npredator = num_agents
self.nprey = 1

# Grid constants
self.base = dim * dim
self.outside_class = self.base + 1
self.prey_class = self.base + 2
self.predator_class = self.base + 3
self.vocab_size = self.base + 4

# Vision window
self.vis_size = 2 * vision + 1

# Flat observation dim: vocab_size * vis_size * vis_size
self.obs_dim = self.vocab_size * self.vis_size * self.vis_size

# Agents
self.agents = [f"agent_{i}" for i in range(num_agents)]

# Reward constants (matching reference)
self.TIMESTEP_PENALTY = -0.05
self.PREY_REWARD = 0.0 # reward for agent ON prey in mixed mode
self.POS_PREY_REWARD = 0.05 # reward for agent ON prey in cooperative mode

Copy link

Copilot AI Mar 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mode argument and constants like POS_PREY_REWARD are defined, but step_env currently always uses the mixed-mode reward logic and never branches on self.mode. This makes mode effectively ignored and can confuse config-driven experiments (e.g., cooperative vs mixed). Either implement the intended mode-specific reward/done logic or remove the unused parameter/constants.

Copilot uses AI. Check for mistakes.
Comment on lines +207 to +231
# Loss function
def loss_fn(params):
policy_loss = -jnp.mean(log_probs * advantages)
value_loss = jnp.mean((values - returns) ** 2)
entropy_loss = -jnp.mean(entropies)

loss = (
policy_loss
+ config["VALUE_COEFF"] * value_loss
+ config["ENTROPY_COEFF"] * entropy_loss
)

return loss, {
"loss": loss,
"policy_loss": policy_loss,
"value_loss": value_loss,
"entropy": -entropy_loss,
}

# Compute gradients and update
(loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(
train_state.params
)
train_state = train_state.apply_gradients(grads=grads)

Copy link

Copilot AI Mar 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

loss_fn(params) does not re-evaluate the policy/value outputs with the provided params; it instead uses log_probs, values, and entropies computed earlier during rollout collection. This typically breaks gradient flow (or at least makes it easy to accidentally differentiate w.r.t. stale values), so the update may produce zero/incorrect gradients. To make this a correct training example, recompute logits/values (and talk logits) inside loss_fn using params (similar to the re-evaluation pattern used in ic3net_train.py).

Copilot uses AI. Check for mistakes.
Comment on lines +50 to +75
def _build_network(config, num_agents, action_dim):
"""Build network based on baseline type and recurrence setting."""
baseline = config.get("BASELINE", "ic3net")
recurrent = config.get("RECURRENT", True)
hidden_dim = config.get("HIDDEN_DIM", 64)

if baseline in ("ic", "iric"):
if recurrent:
network = IndependentLSTM(action_dim=action_dim, hidden_dim=hidden_dim)
else:
network = IndependentMLP(action_dim=action_dim, hidden_dim=hidden_dim)
has_talk = False
else:
hard_attn = (baseline == "ic3net")
kw = dict(
num_agents=num_agents,
action_dim=action_dim,
hidden_dim=hidden_dim,
comm_passes=config.get("COMM_PASSES", 1),
comm_mode=config.get("COMM_MODE", "avg"),
hard_attn=hard_attn,
share_weights=config.get("SHARE_WEIGHTS", False),
encoder_layers=config.get("ENCODER_LAYERS", 1),
)
network = CommNetLSTM(**kw) if recurrent else CommNetDiscrete(**kw)
has_talk = hard_attn
Copy link

Copilot AI Mar 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code treats BASELINE: iric identically to BASELINE: ic (same network, same return/advantage computation). This conflicts with the PR description that IRIC uses a different return formulation (e.g., mean_ratio=0) and means selecting iric currently has no behavioral effect. Please either implement the IRIC-specific return logic in the update step (and document it in config) or adjust the PR description/configs to reflect that IRIC is currently equivalent to IC.

Copilot uses AI. Check for mistakes.
Comment on lines +138 to +145
# 10. Predator-Prey Grid (IC3Net reference)
elif env_id == "predator_prey_grid":
env = PredatorPreyGrid(**env_kwargs)

# 11. Traffic Junction Grid (IC3Net reference)
elif env_id == "traffic_junction_grid":
env = TrafficJunctionGrid(**env_kwargs)

Copy link

Copilot AI Mar 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New environment IDs are added to registration, but the environment API test suite (tests/test_envs_implement_base_api.py) maintains an explicit envs_to_test list and currently won’t exercise these new envs. Consider adding predator_prey_grid and traffic_junction_grid to that list (or switching the test to iterate over registered_envs) to ensure reset/step and space contains() checks run in CI for the new environments.

Copilot uses AI. Check for mistakes.
Comment on lines +188 to +199
if self.hard_attn:
if comm_action is None:
ca = jnp.zeros((batch_size, n))
else:
ca = comm_action.astype(jnp.float32)
if ca.ndim == 1:
ca = jnp.broadcast_to(ca, (batch_size, n))
ca_mask = ca.reshape(batch_size, 1, n)
ca_mask = jnp.broadcast_to(ca_mask, (batch_size, n, n))
ca_mask = jnp.expand_dims(ca_mask, -1)
agent_mask = agent_mask * ca_mask

Copy link

Copilot AI Mar 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In hard-attention mode, the talk/silent gating mask is currently built from comm_action.reshape(batch_size, 1, n) and broadcast to (B, n, n), which applies the talk mask on the receiver dimension. For IC3Net semantics, talk/silent typically controls whether an agent sends messages, so the mask should apply on the sender axis (e.g., shape (B, n, 1) before broadcasting) or otherwise ensure the sender dimension is gated. As written, silent agents may still send messages, and talking agents may be prevented from receiving instead of sending.

Copilot uses AI. Check for mistakes.
Comment on lines +466 to +477
if self.hard_attn:
if comm_action is None:
ca = jnp.zeros((batch_size, n))
else:
ca = comm_action.astype(jnp.float32)
if ca.ndim == 1:
ca = jnp.broadcast_to(ca, (batch_size, n))
ca_mask = ca.reshape(batch_size, 1, n)
ca_mask = jnp.broadcast_to(ca_mask, (batch_size, n, n))
ca_mask = jnp.expand_dims(ca_mask, -1)
agent_mask = agent_mask * ca_mask

Copy link

Copilot AI Mar 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue as in CommNetDiscrete: the hard-attention comm_action mask is broadcast from shape (B, 1, N) to (B, N, N), which gates the receiver axis rather than the sender axis. If comm_action is meant to represent each agent's talk/silent decision, this should mask outgoing messages (sender dimension) to match IC3Net semantics; otherwise silent agents can still contribute to the aggregated message.

Copilot uses AI. Check for mistakes.
Comment on lines +79 to +83
self.observation_spaces = {
a: Box(0.0, 1.0, (self.vocab_size, self.vis_size, self.vis_size))
for a in self.agents
}
self.action_spaces = {a: Discrete(5) for a in self.agents}
Copy link

Copilot AI Mar 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

observation_spaces declares a Box with high=1.0 for all entries, but the observation construction uses .add(1.0) when marking predators. If multiple predators occupy the same cell (which is possible given movement allows overlaps), the predator channel can exceed 1.0 and violate the declared space (and any Space.contains checks). Consider making these marks binary (set/maximum/clip to 1.0) or increasing the Box high bound to accommodate counts.

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants