Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
8a86438
🐛 gracefully handle VF2Layout no-solution-found
flowerthrower Jan 22, 2026
f34a9ae
disable progress bar
flowerthrower Jan 22, 2026
7a89375
🐛 only consider valid circuits for terminal action
flowerthrower Jan 22, 2026
af50584
🚧 debug predictor
flowerthrower Jan 22, 2026
1e3d3d4
🐛handle layout fail more gracefully
flowerthrower Jan 22, 2026
25c8a4a
🚧 restructure state machine
flowerthrower Jan 22, 2026
7eb5138
🚧 update state machine
flowerthrower Jan 22, 2026
37653fb
🎨 add strict policy
flowerthrower Jan 22, 2026
8802fbd
🚧 add og paper strategy
flowerthrower Jan 22, 2026
6db8ee7
🎨 fix og policy
flowerthrower Feb 9, 2026
dcc3810
⏪ remove thesis changes
flowerthrower Feb 17, 2026
440d54c
Merge remote-tracking branch 'origin/main' into bugfix
flowerthrower Feb 17, 2026
450d2dc
⏪ revert thesis updates
flowerthrower Feb 17, 2026
e69cf51
⏪ use og strategy
flowerthrower Feb 17, 2026
934e46c
🐛 fix no-layout found bug
flowerthrower Feb 17, 2026
825773f
Merge branch 'main' into bugfix
flowerthrower Feb 23, 2026
8b27982
Update src/mqt/predictor/rl/predictorenv.py
flowerthrower Feb 23, 2026
729ade7
🎨 docstring
flowerthrower Feb 27, 2026
983a4f6
Merge branch 'main' into bugfix
flowerthrower Mar 2, 2026
de36b57
🐛 fix routing check
flowerthrower Mar 3, 2026
bc667bd
Merge commit '983a4f61b72d10af05bfe762bb72f27aa9304fac' into bugfix
flowerthrower Mar 3, 2026
7c745fe
🐛 use find qubit
flowerthrower Mar 4, 2026
0d46fa9
🐛 fix no valid action bug
flowerthrower Mar 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 11 additions & 34 deletions src/mqt/predictor/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

if TYPE_CHECKING:
from qiskit import QuantumCircuit
from qiskit.circuit import QuantumRegister, Qubit
from qiskit.transpiler import Target
from sklearn.ensemble import RandomForestRegressor

Expand Down Expand Up @@ -62,44 +61,22 @@ def expected_fidelity(qc: QuantumCircuit, device: Target, precision: int = 10) -

if gate_type != "barrier":
assert len(qargs) in [1, 2]
first_qubit_idx = calc_qubit_index(qargs, qc.qregs, 0)
first_qubit_idx = qc.find_bit(qargs[0]).index

if len(qargs) == 1:
specific_fidelity = 1 - device[gate_type][first_qubit_idx,].error
else:
second_qubit_idx = calc_qubit_index(qargs, qc.qregs, 1)
specific_fidelity = 1 - device[gate_type][first_qubit_idx, second_qubit_idx].error

second_qubit_idx = qc.find_bit(qargs[1]).index
try:
specific_fidelity = 1 - device[gate_type][first_qubit_idx, second_qubit_idx].error
except KeyError:
msg = f"Error rate for gate {gate_type} on qubits {first_qubit_idx} and {second_qubit_idx} not found in device properties."
raise KeyError(msg) from None
res *= specific_fidelity

return float(np.round(res, precision).item())


def calc_qubit_index(qargs: list[Qubit], qregs: list[QuantumRegister], index: int) -> int:
"""Calculates the global qubit index for a given quantum circuit and qubit index.

Arguments:
qargs: The qubits of the quantum circuit.
qregs: The quantum registers of the quantum circuit.
index: The index of the qubit in the qargs list.

Returns:
The global qubit index of the given qubit in the quantum circuit.

Raises:
ValueError: If the qubit index is not found in the quantum registers.
"""
offset = 0
for reg in qregs:
if qargs[index] not in reg:
offset += reg.size
else:
qubit_index: int = offset + reg.index(qargs[index])
return qubit_index
error_msg = f"Global qubit index for local qubit {index} index not found."
raise ValueError(error_msg)


def estimated_success_probability(qc: QuantumCircuit, device: Target, precision: int = 10) -> float:
"""Calculates the estimated success probability of a given quantum circuit on a given device.

Expand All @@ -125,7 +102,7 @@ def estimated_success_probability(qc: QuantumCircuit, device: Target, precision:
if gate_type == "barrier" or gate_type == "id":
continue
assert len(qargs) in (1, 2)
first_qubit_idx = calc_qubit_index(qargs, qc.qregs, 0)
first_qubit_idx = qc.find_bit(qargs[0]).index
active_qubits.add(first_qubit_idx)

if len(qargs) == 1: # single-qubit gate
Expand All @@ -140,7 +117,7 @@ def estimated_success_probability(qc: QuantumCircuit, device: Target, precision:
))
exec_time_per_qubit[first_qubit_idx] += duration
else: # multi-qubit gate
second_qubit_idx = calc_qubit_index(qargs, qc.qregs, 1)
second_qubit_idx = qc.find_bit(qargs[1]).index
active_qubits.add(second_qubit_idx)
duration = device[gate_type][first_qubit_idx, second_qubit_idx].duration
op_times.append((gate_type, [first_qubit_idx, second_qubit_idx], duration, "s"))
Expand Down Expand Up @@ -191,7 +168,7 @@ def estimated_success_probability(qc: QuantumCircuit, device: Target, precision:
continue

assert len(qargs) in (1, 2)
first_qubit_idx = calc_qubit_index(qargs, qc.qregs, 0)
first_qubit_idx = scheduled_circ.find_bit(qargs[0]).index

if len(qargs) == 1:
if gate_type == "measure":
Expand All @@ -213,7 +190,7 @@ def estimated_success_probability(qc: QuantumCircuit, device: Target, precision:
continue
res *= 1 - device[gate_type][first_qubit_idx,].error
else:
second_qubit_idx = calc_qubit_index(qargs, qc.qregs, 1)
second_qubit_idx = scheduled_circ.find_bit(qargs[1]).index
res *= 1 - device[gate_type][first_qubit_idx, second_qubit_idx].error

if qiskit_version >= "2.0.0":
Expand Down
3 changes: 2 additions & 1 deletion src/mqt/predictor/rl/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@

from bqskit import Circuit
from pytket._tket.passes import BasePass as tket_BasePass
from qiskit.passmanager import PropertySet
from qiskit.transpiler.basepasses import BasePass as qiskit_BasePass


Expand Down Expand Up @@ -143,7 +144,7 @@ class DeviceDependentAction(Action):
Callable[..., tuple[Any, ...] | Circuit],
]
)
do_while: Callable[[dict[str, Circuit]], bool] | None = None
do_while: Callable[[PropertySet], bool] | None = None


# Registry of actions
Expand Down
7 changes: 4 additions & 3 deletions src/mqt/predictor/rl/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,12 @@ def train_model(
"""
if test:
set_random_seed(0) # for reproducibility
n_steps = 10
n_epochs = 1
batch_size = 10
n_steps = 32
n_epochs = 2
batch_size = 8
progress_bar = False
else:
set_random_seed(0)
# default PPO values
n_steps = 2048
n_epochs = 10
Expand Down
184 changes: 152 additions & 32 deletions src/mqt/predictor/rl/predictorenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from bqskit import Circuit
from qiskit.passmanager.base_tasks import Task
from qiskit.transpiler import Target
from qiskit.transpiler import Layout, Target

from mqt.predictor.reward import figure_of_merit
from mqt.predictor.rl.actions import Action
Expand All @@ -40,7 +40,6 @@
from qiskit import QuantumCircuit
from qiskit.passmanager.flow_controllers import DoWhileController
from qiskit.transpiler import CouplingMap, PassManager, TranspileLayout
from qiskit.transpiler.passes import CheckMap, GatesInBasis
from qiskit.transpiler.passes.layout.vf2_layout import VF2LayoutStopReason

from mqt.predictor.hellinger import get_hellinger_model_path
Expand Down Expand Up @@ -189,23 +188,21 @@ def step(self, action: int) -> tuple[dict[str, Any], float, bool, bool, dict[Any
self.state: QuantumCircuit = altered_qc
self.num_steps += 1

self.state._layout = self.layout # noqa: SLF001

self.valid_actions = self.determine_valid_actions_for_state()
if len(self.valid_actions) == 0:
msg = "No valid actions left."
raise RuntimeError(msg)

if action == self.action_terminate_index:
assert action in self.valid_actions, "Terminate action is not valid but was chosen."
reward_val = self.calculate_reward()
done = True
else:
reward_val = 0
done = False

# in case the Qiskit.QuantumCircuit has unitary or u gates in it, decompose them (because otherwise qiskit will throw an error when applying the BasisTranslator
if self.state.count_ops().get("unitary"): # ty: ignore[invalid-argument-type]
self.state = self.state.decompose(gates_to_decompose="unitary")

self.state._layout = self.layout # noqa: SLF001
obs = create_feature_dict(self.state)
return obs, reward_val, done, False, {}

Expand Down Expand Up @@ -268,10 +265,14 @@ def action_masks(self) -> list[bool]:
"""Returns a list of valid actions for the current state."""
action_mask = [action in self.valid_actions for action in self.action_set]

# it is not clear how tket will handle the layout, so we remove all actions that are from "origin"=="tket" if a layout is set
# TKET layout/optimization actions must not run after a Qiskit layout has been set
# (it is not clear how tket will handle the layout). TKET routing actions are
# designed to work after a Qiskit layout via PreProcessTKETRoutingAfterQiskitLayout.
if self.layout is not None:
action_mask = [
action_mask[i] and self.action_set[i].origin != CompilationOrigin.TKET for i in range(len(action_mask))
action_mask[i]
and (self.action_set[i].origin != CompilationOrigin.TKET or i in self.actions_routing_indices)
for i in range(len(action_mask))
]

if self.has_parameterized_gates or self.layout is not None:
Expand Down Expand Up @@ -342,9 +343,16 @@ def _apply_qiskit_action(self, action: Action, action_index: int) -> QuantumCirc
):
altered_qc = self._handle_qiskit_layout_postprocessing(action, pm, altered_qc)

elif action_index in self.actions_routing_indices and self.layout:
elif (
action_index in self.actions_routing_indices and self.layout and pm.property_set["final_layout"] is not None
):
self.layout.final_layout = pm.property_set["final_layout"]

# BasisTranslator errors on unitary gates; decompose them immediately so
# the circuit is always in a consistent state after a Qiskit action.
if altered_qc.count_ops().get("unitary"): # ty: ignore[invalid-argument-type]
altered_qc = altered_qc.decompose(gates_to_decompose="unitary")

return altered_qc

def _handle_qiskit_layout_postprocessing(
Expand All @@ -357,8 +365,13 @@ def _handle_qiskit_layout_postprocessing(
assert self.layout is not None
altered_qc, _ = postprocess_vf2postlayout(altered_qc, post_layout, self.layout)
elif action.name == "VF2Layout":
assert pm.property_set["VF2Layout_stop_reason"] == VF2LayoutStopReason.SOLUTION_FOUND
assert pm.property_set["layout"]
if pm.property_set["VF2Layout_stop_reason"] != VF2LayoutStopReason.SOLUTION_FOUND:
logger.warning(
"VF2Layout pass did not find a solution. Reason: %s",
pm.property_set["VF2Layout_stop_reason"],
)
else:
assert pm.property_set["layout"]
else:
assert pm.property_set["layout"]

Expand All @@ -385,7 +398,7 @@ def _apply_tket_action(self, action: Action, action_index: int) -> QuantumCircui

qbs = tket_qc.qubits
tket_qc.rename_units({qbs[i]: Qubit("q", i) for i in range(len(qbs))})
altered_qc = tk_to_qiskit(tket_qc)
altered_qc = tk_to_qiskit(tket_qc, replace_implicit_swaps=True)

if action_index in self.actions_routing_indices:
assert self.layout is not None
Expand Down Expand Up @@ -428,27 +441,134 @@ def _apply_bqskit_action(self, action: Action, action_index: int) -> QuantumCirc

return bqskit_to_qiskit(bqskit_compiled_qc)

def determine_valid_actions_for_state(self) -> list[int]:
"""Determines and returns the valid actions for the current state."""
check_nat_gates = GatesInBasis(basis_gates=self.device.operation_names)
check_nat_gates(self.state)
only_nat_gates = check_nat_gates.property_set["all_gates_in_basis"]
def is_circuit_laid_out(self, circuit: QuantumCircuit, layout: TranspileLayout | Layout) -> bool:
"""True if every logical qubit in the circuit has a physical assignment."""
if isinstance(layout, TranspileLayout):
# Use final_layout if available; otherwise fallback to initial_layout
layout = layout.final_layout or layout.initial_layout

if not only_nat_gates:
actions = self.actions_synthesis_indices + self.actions_opt_indices
if self.layout is not None:
actions += self.actions_routing_indices
return actions
v2p = layout.get_virtual_bits()
for instr in circuit.data:
for q in instr.qubits:
if q not in v2p:
# Logical qubit not assigned
return False
return True
Comment on lines +444 to +456
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Validate all logical qubits, not just those appearing in instructions.
The docstring says “every logical qubit”; idle qubits are currently skipped.

🔧 Suggested fix
-        for instr in circuit.data:
-            for q in instr.qubits:
-                if q not in v2p:
-                    # Logical qubit not assigned
-                    return False
+        for q in circuit.qubits:
+            if q not in v2p:
+                # Logical qubit not assigned
+                return False
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/mqt/predictor/rl/predictorenv.py` around lines 435 - 447, The current
is_circuit_laid_out only checks qubits that appear in instructions and misses
idle logical qubits; update is_circuit_laid_out (and its TranspileLayout
handling) to validate every logical qubit from the circuit (e.g., iterate
circuit.qubits or range(circuit.num_qubits)) against layout.get_virtual_bits()
instead of iterating only instr.qubits, and return False if any circuit qubit is
not present in v2p; keep the existing fallback to layout.final_layout or
layout.initial_layout and handle missing/None layout gracefully.


check_mapping = CheckMap(coupling_map=self.device.build_coupling_map())
check_mapping(self.state)
mapped = check_mapping.property_set["is_swap_mapped"]
def is_circuit_synthesized(self, circuit: QuantumCircuit) -> bool:
"""Check if the circuit uses only native gates of the device.

if mapped and self.layout is not None: # The circuit is correctly mapped.
return [self.action_terminate_index, *self.actions_opt_indices]
Verifies that every gate name in the circuit is present in
``device.operation_names``, equivalent to the ``GatesInBasis`` pass.

if self.layout is not None: # The circuit is not yet mapped but a layout is set.
return self.actions_routing_indices
Args:
circuit: QuantumCircuit to check.

# No layout applied yet
return self.actions_mapping_indices + self.actions_layout_indices + self.actions_opt_indices
Returns:
True if all gates are native to the device.
"""
native_names = set(self.device.operation_names)
return all(
instr.operation.name in native_names or instr.operation.name in ("barrier", "measure")
for instr in circuit.data
)

def is_circuit_routed(self, circuit: QuantumCircuit, coupling_map: CouplingMap) -> bool:
"""Check if a circuit is fully routed to the device, including directionality.

A circuit is considered routed if all two-qubit gates are on qubit pairs
that exist as directed edges in the device coupling map.

After a layout pass the circuit's qubits are already physical qubits, so
``circuit.find_bit(q).index`` gives the physical index directly —
consistent with how ``reward.py`` looks up gate calibrations.

Args:
circuit: QuantumCircuit to check.
coupling_map: CouplingMap of the target device.

Returns:
True if fully routed, False otherwise.
"""
directed_edges = set(coupling_map.get_edges())
for instr in circuit.data:
if len(instr.qubits) == 2:
q0 = circuit.find_bit(instr.qubits[0]).index
q1 = circuit.find_bit(instr.qubits[1]).index
if (q0, q1) not in directed_edges:
return False
return True

def determine_valid_actions_for_state(self) -> list[int]:
"""Determine valid actions based on circuit state: synthesized, mapped, routed."""
synthesized = self.is_circuit_synthesized(self.state)
laid_out = self.is_circuit_laid_out(self.state, self.layout) if self.layout else False
# Routing is only allowed after layout
routed = (
self.is_circuit_routed(self.state, CouplingMap(self.device.build_coupling_map())) if laid_out else False
)

actions = []

og = True # Original (restricted) MDP
flexible = False # General MDP

# Initial state
if not synthesized and not laid_out and not routed:
if flexible:
actions.extend(self.actions_synthesis_indices)
actions.extend(self.actions_mapping_indices)
actions.extend(self.actions_layout_indices)
actions.extend(self.actions_opt_indices)
if og:
actions.extend(self.actions_synthesis_indices)
actions.extend(self.actions_opt_indices)

if synthesized and not laid_out and not routed:
if flexible:
actions.extend(self.actions_mapping_indices)
actions.extend(self.actions_layout_indices)
actions.extend(self.actions_opt_indices)
if og:
actions.extend(self.actions_mapping_indices)
actions.extend(self.actions_layout_indices)
actions.extend(self.actions_opt_indices)

# Not *depicted* in paper; necessary because optimization can destroy the native gate set
if not synthesized and laid_out and not routed:
if flexible:
actions.extend(self.actions_synthesis_indices)
actions.extend(self.actions_routing_indices)
actions.extend(self.actions_opt_indices)
if og:
actions.extend(self.actions_synthesis_indices)
actions.extend(self.actions_routing_indices)
actions.extend(self.actions_opt_indices)

# Not *depicted* in paper; necessary because of mapping-only passes
if synthesized and laid_out and not routed:
if flexible:
actions.extend(self.actions_routing_indices)
actions.extend(self.actions_opt_indices)
if og:
actions.extend(self.actions_routing_indices)

# Not *depicted* in paper; necessary because routing can insert non-native SWAPs
if not synthesized and laid_out and routed:
if flexible:
actions.extend(self.actions_synthesis_indices)
actions.extend(self.actions_opt_indices)
if og:
actions.extend(self.actions_synthesis_indices)
actions.extend(self.actions_opt_indices)

# Final state
if synthesized and laid_out and routed:
if flexible:
actions.extend([self.action_terminate_index])
actions.extend(self.actions_opt_indices)
if og:
actions.extend([self.action_terminate_index])
actions.extend(self.actions_opt_indices)

return actions
2 changes: 1 addition & 1 deletion tests/compilation/test_integration_further_SDKs.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def test_tket_routing(available_actions_dict: dict[PassType, list[Action]]) -> N
qubit_map = {qbs[i]: Qubit("q", i) for i in range(len(qbs))}
tket_qc.rename_units(qubit_map)

mapped_qc = tk_to_qiskit(tket_qc)
mapped_qc = tk_to_qiskit(tket_qc, replace_implicit_swaps=True)

final_layout = final_layout_pytket_to_qiskit(tket_qc, mapped_qc)

Expand Down
Loading
Loading