Fix GraphEdgeActionMLP backward policy bugs #499

Draft
josephdviviano wants to merge 1 commit into master from fix-graph-mlp-backward

@josephdviviano
Collaborator

Summary

  • Fixes two bugs in GraphEdgeActionMLP.forward() that crashed the is_backward=True path during TB loss computation with recalculate_all_logprobs=True
  • Bug 1: node_class_logits was assigned only in the forward branch but referenced unconditionally in the returned TensorDict, raising UnboundLocalError on the backward path. Fixed by computing the class logits unconditionally (matching the GraphActionGNN pattern).
  • Bug 2: states_tensor.x (shape (N, 1)) was not squeezed before nn.Embedding, so the extra dimension carried through to a 3D node_index_logits. Fixed by adding .squeeze(-1) (matching GraphActionGNN).
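
The failure mode behind Bug 1 can be reproduced without any of the library code. The sketch below is a stand-alone illustration, not the actual GraphEdgeActionMLP source: the function names, the placeholder logit values, and the plain dict standing in for a TensorDict are all hypothetical.

```python
def buggy_forward(is_backward: bool) -> dict:
    """Hypothetical stand-in: the variable is bound only on the forward path."""
    if not is_backward:
        node_class_logits = [0.0, 0.0]  # placeholder values
    # On the backward path the name was never bound, so this line raises.
    return {"node_class": node_class_logits}


def fixed_forward(is_backward: bool) -> dict:
    """Hypothetical stand-in: the variable is bound before any branching."""
    node_class_logits = [0.0, 0.0]  # computed unconditionally
    if is_backward:
        direction = "backward"
    else:
        direction = "forward"
    return {"node_class": node_class_logits, "direction": direction}


try:
    buggy_forward(is_backward=True)
except UnboundLocalError as err:
    print("buggy backward path:", type(err).__name__)

print("fixed backward path:", fixed_forward(is_backward=True)["direction"])
```

Because Python decides at compile time that a name assigned anywhere in a function is local, the buggy version fails only when the unassigned branch is actually taken, which is why the forward path never exposed it.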

Tests added

  • test_graph_module_output_shapes: 8 parametrized cases (GNN/MLP x forward/backward x directed/undirected)
  • test_graph_module_output_shapes_on_empty_graphs: 4 cases for initial (empty) states
  • test_graph_tb_pipeline: 4 cases running a full TB training step with recalculate_all_logprobs=True
  • test_graph_ring_smoke: parametrized with use_gnn=[True, False]
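
The count of 8 cases for test_graph_module_output_shapes follows from crossing three two-valued axes. The sketch below only enumerates that grid; the axis labels and case ids are illustrative assumptions and are not taken from the test file.

```python
from itertools import product

# Hypothetical labels for the three axes named in the test description.
module_kinds = ["gnn", "mlp"]
directions = ["forward", "backward"]
graph_kinds = ["directed", "undirected"]

# Cross product of the axes: 2 x 2 x 2 combinations.
cases = list(product(module_kinds, directions, graph_kinds))
case_ids = ["-".join(case) for case in cases]

print(len(cases))
for case_id in case_ids:
    print(case_id)
```

A list like this could be handed to pytest.mark.parametrize along with ids=case_ids, or the same grid could be produced by stacking three parametrize decorators.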

Test plan

  • All 1527 existing tests pass
  • 16 new graph module tests pass
  • CI smoke test for graph_ring with MLP

🤖 Generated with Claude Code

Two bugs in GraphEdgeActionMLP.forward() when is_backward=True:

1. node_class_logits was only assigned in the forward branch but
   referenced unconditionally in the return TensorDict, causing
   UnboundLocalError. Fix: compute node_class_logits and
   edge_class_logits unconditionally before the if/else branch,
   matching GraphActionGNN pattern.

2. states_tensor.x (shape: total_nodes, 1) was passed directly to
   nn.Embedding without squeezing, producing 3D output that
   propagated through pad_sequence into 3D node_index_logits.
   Fix: squeeze x before embedding, matching GraphActionGNN.
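
A minimal PyTorch sketch of the shape problem in item 2, assuming node features are stored as a column of integer class ids. The sizes and variable names here are hypothetical and are not taken from GraphEdgeActionMLP.

```python
import torch
import torch.nn as nn

num_node_classes = 4  # hypothetical vocabulary size
embedding_dim = 8     # hypothetical embedding width
embedding = nn.Embedding(num_node_classes, embedding_dim)

# Node features as a column of integer ids, shape (total_nodes, 1).
x = torch.tensor([[0], [2], [1], [3], [2]])

# nn.Embedding appends the embedding dimension to the input shape,
# so a (N, 1) input yields a 3D (N, 1, embedding_dim) output.
unsqueezed = embedding(x)

# Dropping the trailing singleton dimension first yields (N, embedding_dim).
squeezed = embedding(x.squeeze(-1))

print(tuple(unsqueezed.shape))  # (5, 1, 8)
print(tuple(squeezed.shape))    # (5, 8)
```

Any later step that stacks or pads per-graph node embeddings keeps the extra singleton dimension, which is consistent with the description above of it reaching the logits.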

Tests added:
- test_graph_module_output_shapes (8 parametrized cases)
- test_graph_module_output_shapes_on_empty_graphs (4 cases)
- test_graph_tb_pipeline (4 cases)
- test_graph_ring_smoke parametrized with use_gnn=[True, False]

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@codecov

codecov bot commented Mar 24, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.20%. Comparing base (f0605a8) to head (612d395).

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #499      +/-   ##
==========================================
+ Coverage   72.48%   73.20%   +0.72%     
==========================================
  Files          55       55              
  Lines        8519     8518       -1     
  Branches     1090     1090              
==========================================
+ Hits         6175     6236      +61     
+ Misses       1957     1896      -61     
+ Partials      387      386       -1     
| Files with missing lines | Coverage Δ |
|---|---|
| src/gfn/utils/modules.py | 81.36% <100.00%> (+8.78%) ⬆️ |