Skip to content

Commit 524fb47

Browse files
authored
Implement Index<NodeIndex> for DAGCircuit (Qiskit#13683)
This removes the syntax noise of the `dag.dag()` calls when indexing by `NodeIndex`. As it happens, this is _almost_ all of the reason we even use the underlying graph object in `accelerate`. The only exceptions are some needless defensive programming in `RemoveDiagonalGatesBeforeMeasure` (which really is the same thing underneath anyway), and in the graphviz utilities, which is legitimate.
1 parent 3150351 commit 524fb47

15 files changed

+47
-41
lines changed

crates/accelerate/src/barrier_before_final_measurement.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ pub fn barrier_before_final_measurements(
3939
}
4040
dag.bfs_successors(node)
4141
.all(|(_, child_successors)| {
42-
child_successors.iter().all(|suc| match dag.dag()[*suc] {
42+
child_successors.iter().all(|suc| match dag[*suc] {
4343
NodeType::Operation(ref suc_inst) => is_exactly_final(suc_inst),
4444
_ => true,
4545
})
@@ -57,7 +57,7 @@ pub fn barrier_before_final_measurements(
5757
let final_packed_ops: Vec<PackedInstruction> = ordered_node_indices
5858
.into_iter()
5959
.map(|node| {
60-
let NodeType::Operation(ref inst) = dag.dag()[node] else {
60+
let NodeType::Operation(ref inst) = dag[node] else {
6161
unreachable!()
6262
};
6363
let res = inst.clone();

crates/accelerate/src/basis/basis_translator/mod.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ fn apply_translation(
428428
let mut is_updated = false;
429429
let mut out_dag = dag.copy_empty_like(py, "alike")?;
430430
for node in dag.topological_op_nodes()? {
431-
let node_obj = dag.dag()[node].unwrap_operation();
431+
let node_obj = dag[node].unwrap_operation();
432432
let node_qarg = dag.get_qargs(node_obj.qubits);
433433
let node_carg = dag.get_cargs(node_obj.clbits);
434434
let qubit_set: HashSet<Qubit> = HashSet::from_iter(node_qarg.iter().copied());
@@ -606,7 +606,7 @@ fn replace_node(
606606
}
607607
if node.params_view().is_empty() {
608608
for inner_index in target_dag.topological_op_nodes()? {
609-
let inner_node = &target_dag.dag()[inner_index].unwrap_operation();
609+
let inner_node = &target_dag[inner_index].unwrap_operation();
610610
let old_qargs = dag.get_qargs(node.qubits);
611611
let old_cargs = dag.get_cargs(node.clbits);
612612
let new_qubits: Vec<Qubit> = target_dag
@@ -667,7 +667,7 @@ fn replace_node(
667667
.zip(node.params_view())
668668
.into_py_dict_bound(py);
669669
for inner_index in target_dag.topological_op_nodes()? {
670-
let inner_node = &target_dag.dag()[inner_index].unwrap_operation();
670+
let inner_node = &target_dag[inner_index].unwrap_operation();
671671
let old_qargs = dag.get_qargs(node.qubits);
672672
let old_cargs = dag.get_cargs(node.clbits);
673673
let new_qubits: Vec<Qubit> = target_dag

crates/accelerate/src/commutation_analysis.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ pub(crate) fn analyze_commutations_inner(
8080
// if the node is an input/output node, they do not commute, so we only
8181
// continue if the nodes are operation nodes
8282
if let (NodeType::Operation(packed_inst0), NodeType::Operation(packed_inst1)) =
83-
(&dag.dag()[current_gate_idx], &dag.dag()[*prev_gate_idx])
83+
(&dag[current_gate_idx], &dag[*prev_gate_idx])
8484
{
8585
let op1 = packed_inst0.op.view();
8686
let op2 = packed_inst1.op.view();

crates/accelerate/src/commutation_cancellation.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,14 @@ pub(crate) fn cancel_commutations(
105105
if let Some(wire_commutation_set) = commutation_set.get(&Wire::Qubit(wire)) {
106106
for (com_set_idx, com_set) in wire_commutation_set.iter().enumerate() {
107107
if let Some(&nd) = com_set.first() {
108-
if !matches!(dag.dag()[nd], NodeType::Operation(_)) {
108+
if !matches!(dag[nd], NodeType::Operation(_)) {
109109
continue;
110110
}
111111
} else {
112112
continue;
113113
}
114114
for node in com_set.iter() {
115-
let instr = match &dag.dag()[*node] {
115+
let instr = match &dag[*node] {
116116
NodeType::Operation(instr) => instr,
117117
_ => panic!("Unexpected type in commutation set."),
118118
};
@@ -198,7 +198,7 @@ pub(crate) fn cancel_commutations(
198198
let mut total_angle: f64 = 0.0;
199199
let mut total_phase: f64 = 0.0;
200200
for current_node in cancel_set {
201-
let node_op = match &dag.dag()[*current_node] {
201+
let node_op = match &dag[*current_node] {
202202
NodeType::Operation(instr) => instr,
203203
_ => panic!("Unexpected type in commutation set run."),
204204
};

crates/accelerate/src/consolidate_blocks.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ pub(crate) fn consolidate_blocks(
100100
block_qargs.clear();
101101
if block.len() == 1 {
102102
let inst_node = block[0];
103-
let inst = dag.dag()[inst_node].unwrap_operation();
103+
let inst = dag[inst_node].unwrap_operation();
104104
if !is_supported(
105105
target,
106106
basis_gates.as_ref(),
@@ -123,7 +123,7 @@ pub(crate) fn consolidate_blocks(
123123
let mut basis_count: usize = 0;
124124
let mut outside_basis = false;
125125
for node in &block {
126-
let inst = dag.dag()[*node].unwrap_operation();
126+
let inst = dag[*node].unwrap_operation();
127127
block_qargs.extend(dag.get_qargs(inst.qubits));
128128
all_block_gates.insert(*node);
129129
if inst.op.name() == basis_gate_name {
@@ -151,7 +151,7 @@ pub(crate) fn consolidate_blocks(
151151
block_qargs.len() as u32,
152152
0,
153153
block.iter().map(|node| {
154-
let inst = dag.dag()[*node].unwrap_operation();
154+
let inst = dag[*node].unwrap_operation();
155155

156156
Ok((
157157
inst.op.clone(),
@@ -242,7 +242,7 @@ pub(crate) fn consolidate_blocks(
242242
continue;
243243
}
244244
let first_inst_node = run[0];
245-
let first_inst = dag.dag()[first_inst_node].unwrap_operation();
245+
let first_inst = dag[first_inst_node].unwrap_operation();
246246
let first_qubits = dag.get_qargs(first_inst.qubits);
247247

248248
if run.len() == 1
@@ -272,7 +272,7 @@ pub(crate) fn consolidate_blocks(
272272
if all_block_gates.contains(node) {
273273
already_in_block = true;
274274
}
275-
let gate = dag.dag()[*node].unwrap_operation();
275+
let gate = dag[*node].unwrap_operation();
276276
let operator = match get_matrix_from_inst(py, gate) {
277277
Ok(mat) => mat,
278278
Err(_) => {

crates/accelerate/src/convert_2q_block_matrix.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ pub fn blocks_to_matrix(
7272
let mut one_qubit_components_modified = false;
7373
let mut output_matrix: Option<Array2<Complex64>> = None;
7474
for node in op_list {
75-
let inst = dag.dag()[*node].unwrap_operation();
75+
let inst = dag[*node].unwrap_operation();
7676
let op_matrix = get_matrix_from_inst(py, inst)?;
7777
match dag
7878
.get_qargs(inst.qubits)

crates/accelerate/src/elide_permutations.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ fn run(py: Python, dag: &mut DAGCircuit) -> PyResult<Option<(DAGCircuit, Vec<usi
3939
// note that DAGCircuit::copy_empty_like clones the interners
4040
let mut new_dag = dag.copy_empty_like(py, "alike")?;
4141
for node_index in dag.topological_op_nodes()? {
42-
if let NodeType::Operation(inst) = &dag.dag()[node_index] {
42+
if let NodeType::Operation(inst) = &dag[node_index] {
4343
match (inst.op.name(), inst.condition()) {
4444
("swap", None) => {
4545
let qargs = dag.get_qargs(inst.qubits);

crates/accelerate/src/euler_one_qubit_decomposer.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -1089,7 +1089,7 @@ pub(crate) fn optimize_1q_gates_decomposition(
10891089
Some(_) => 1.,
10901090
None => raw_run.len() as f64,
10911091
};
1092-
let qubit: PhysicalQubit = if let NodeType::Operation(inst) = &dag.dag()[raw_run[0]] {
1092+
let qubit: PhysicalQubit = if let NodeType::Operation(inst) = &dag[raw_run[0]] {
10931093
PhysicalQubit::new(dag.get_qargs(inst.qubits)[0].0)
10941094
} else {
10951095
unreachable!("nodes in runs will always be op nodes")
@@ -1175,7 +1175,7 @@ pub(crate) fn optimize_1q_gates_decomposition(
11751175
let operator = raw_run
11761176
.iter()
11771177
.map(|node_index| {
1178-
let node = &dag.dag()[*node_index];
1178+
let node = &dag[*node_index];
11791179
if let NodeType::Operation(inst) = node {
11801180
if let Some(target) = target {
11811181
error *= compute_error_term_from_target(inst.op.name(), target, qubit);
@@ -1218,7 +1218,7 @@ pub(crate) fn optimize_1q_gates_decomposition(
12181218
let mut outside_basis = false;
12191219
if let Some(basis) = basis_gates {
12201220
for node in &raw_run {
1221-
if let NodeType::Operation(inst) = &dag.dag()[*node] {
1221+
if let NodeType::Operation(inst) = &dag[*node] {
12221222
if !basis.contains(inst.op.name()) {
12231223
outside_basis = true;
12241224
break;

crates/accelerate/src/gate_direction.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ where
351351
}
352352

353353
for (node, op_blocks) in ops_to_replace {
354-
let packed_inst = dag.dag()[node].unwrap_operation();
354+
let packed_inst = dag[node].unwrap_operation();
355355
let OperationRef::Instruction(py_inst) = packed_inst.op.view() else {
356356
panic!("PyInstruction is expected");
357357
};

crates/accelerate/src/inverse_cancellation.rs

+9-10
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ fn run_on_self_inverse(
6565
let mut chunk: Vec<NodeIndex> = Vec::new();
6666
let max_index = gate_cancel_run.len() - 1;
6767
for (i, cancel_gate) in gate_cancel_run.iter().enumerate() {
68-
let node = &dag.dag()[*cancel_gate];
68+
let node = &dag[*cancel_gate];
6969
if let NodeType::Operation(inst) = node {
7070
if gate_eq(py, inst, &gate)? {
7171
chunk.push(*cancel_gate);
@@ -78,13 +78,12 @@ fn run_on_self_inverse(
7878
if i == max_index {
7979
partitions.push(std::mem::take(&mut chunk));
8080
} else {
81-
let next_qargs = if let NodeType::Operation(next_inst) =
82-
&dag.dag()[gate_cancel_run[i + 1]]
83-
{
84-
next_inst.qubits
85-
} else {
86-
panic!("Not an op node")
87-
};
81+
let next_qargs =
82+
if let NodeType::Operation(next_inst) = &dag[gate_cancel_run[i + 1]] {
83+
next_inst.qubits
84+
} else {
85+
panic!("Not an op node")
86+
};
8887
if inst.qubits != next_qargs {
8988
partitions.push(std::mem::take(&mut chunk));
9089
}
@@ -132,8 +131,8 @@ fn run_on_inverse_pairs(
132131
for nodes in runs {
133132
let mut i = 0;
134133
while i < nodes.len() - 1 {
135-
if let NodeType::Operation(inst) = &dag.dag()[nodes[i]] {
136-
if let NodeType::Operation(next_inst) = &dag.dag()[nodes[i + 1]] {
134+
if let NodeType::Operation(inst) = &dag[nodes[i]] {
135+
if let NodeType::Operation(next_inst) = &dag[nodes[i + 1]] {
137136
if inst.qubits == next_inst.qubits
138137
&& ((gate_eq(py, inst, &gate_0)? && gate_eq(py, next_inst, &gate_1)?)
139138
|| (gate_eq(py, inst, &gate_1)?

crates/accelerate/src/remove_diagonal_gates_before_measure.rs

+2-3
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ fn run_remove_diagonal_before_measure(dag: &mut DAGCircuit) -> PyResult<()> {
5353
.next()
5454
.expect("index is an operation node, so it must have a predecessor.");
5555

56-
match &dag.dag()[predecessor] {
56+
match &dag[predecessor] {
5757
NodeType::Operation(pred_inst) => match pred_inst.standard_gate() {
5858
Some(gate) => {
5959
if DIAGONAL_1Q_GATES.contains(&gate) {
@@ -64,8 +64,7 @@ fn run_remove_diagonal_before_measure(dag: &mut DAGCircuit) -> PyResult<()> {
6464
let successors = dag.quantum_successors(predecessor);
6565
let remove_s = successors
6666
.map(|s| {
67-
let node_s = &dag.dag()[s];
68-
if let NodeType::Operation(inst_s) = node_s {
67+
if let NodeType::Operation(inst_s) = &dag[s] {
6968
inst_s.op.name() == "measure"
7069
} else {
7170
false

crates/accelerate/src/split_2q_unitaries.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ pub fn split_2q_unitaries(
3434
let nodes: Vec<NodeIndex> = dag.op_node_indices(false).collect();
3535

3636
for node in nodes {
37-
if let NodeType::Operation(inst) = &dag.dag()[node] {
37+
if let NodeType::Operation(inst) = &dag[node] {
3838
let qubits = dag.get_qargs(inst.qubits).to_vec();
3939
// We only attempt to split UnitaryGate objects, but this could be extended in future
4040
// -- however we need to ensure that we can compile the resulting single-qubit unitaries

crates/accelerate/src/unitary_synthesis.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ fn apply_synth_dag(
112112
synth_dag: &DAGCircuit,
113113
) -> PyResult<()> {
114114
for out_node in synth_dag.topological_op_nodes()? {
115-
let mut out_packed_instr = synth_dag.dag()[out_node].unwrap_operation().clone();
115+
let mut out_packed_instr = synth_dag[out_node].unwrap_operation().clone();
116116
let synth_qargs = synth_dag.get_qargs(out_packed_instr.qubits);
117117
let mapped_qargs: Vec<Qubit> = synth_qargs
118118
.iter()
@@ -237,7 +237,7 @@ fn py_run_main_loop(
237237

238238
// Iterate over dag nodes and determine unitary synthesis approach
239239
for node in dag.topological_op_nodes()? {
240-
let mut packed_instr = dag.dag()[node].unwrap_operation().clone();
240+
let mut packed_instr = dag[node].unwrap_operation().clone();
241241

242242
if packed_instr.op.control_flow() {
243243
let OperationRef::Instruction(py_instr) = packed_instr.op.view() else {
@@ -486,7 +486,7 @@ fn run_2q_unitary_synthesis(
486486
.topological_op_nodes()
487487
.expect("Unexpected error in dag.topological_op_nodes()")
488488
.map(|node| {
489-
let NodeType::Operation(inst) = &synth_dag.dag()[node] else {
489+
let NodeType::Operation(inst) = &synth_dag[node] else {
490490
unreachable!("DAG node must be an instruction")
491491
};
492492
let inst_qubits = synth_dag
@@ -1002,7 +1002,7 @@ fn synth_su4_dag(
10021002
Some(preferred_dir) => {
10031003
let mut synth_direction: Option<Vec<u32>> = None;
10041004
for node in synth_dag.topological_op_nodes()? {
1005-
let inst = &synth_dag.dag()[node].unwrap_operation();
1005+
let inst = &synth_dag[node].unwrap_operation();
10061006
if inst.op.num_qubits() == 2 {
10071007
let qargs = synth_dag.get_qargs(inst.qubits);
10081008
synth_direction = Some(vec![qargs[0].0, qargs[1].0]);
@@ -1066,7 +1066,7 @@ fn reversed_synth_su4_dag(
10661066
let mut target_dag = synth_dag.copy_empty_like(py, "alike")?;
10671067
let flip_bits: [Qubit; 2] = [Qubit(1), Qubit(0)];
10681068
for node in synth_dag.topological_op_nodes()? {
1069-
let mut inst = synth_dag.dag()[node].unwrap_operation().clone();
1069+
let mut inst = synth_dag[node].unwrap_operation().clone();
10701070
let qubits: Vec<Qubit> = synth_dag
10711071
.qargs_interner()
10721072
.get(inst.qubits)

crates/circuit/src/converters.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
#[cfg(feature = "cache_pygates")]
1414
use std::sync::OnceLock;
1515

16-
use ::pyo3::prelude::*;
1716
use hashbrown::HashMap;
17+
use pyo3::prelude::*;
1818
use pyo3::{
1919
intern,
2020
types::{PyDict, PyList},
@@ -106,7 +106,7 @@ pub fn dag_to_circuit(
106106
dag.qargs_interner().clone(),
107107
dag.cargs_interner().clone(),
108108
dag.topological_op_nodes()?.map(|node_index| {
109-
let NodeType::Operation(ref instr) = dag.dag()[node_index] else {
109+
let NodeType::Operation(ref instr) = dag[node_index] else {
110110
unreachable!(
111111
"The received node from topological_op_nodes() is not an Operation node."
112112
)

crates/circuit/src/dag_circuit.rs

+8
Original file line numberDiff line numberDiff line change
@@ -6967,6 +6967,14 @@ impl DAGCircuit {
69676967
}
69686968
}
69696969

6970+
impl ::std::ops::Index<NodeIndex> for DAGCircuit {
6971+
type Output = NodeType;
6972+
6973+
fn index(&self, index: NodeIndex) -> &Self::Output {
6974+
self.dag.index(index)
6975+
}
6976+
}
6977+
69706978
/// Add to global phase. Global phase can only be Float or ParameterExpression so this
69716979
/// does not handle the full possibility of parameter values.
69726980
pub(crate) fn add_global_phase(py: Python, phase: &Param, other: &Param) -> PyResult<Param> {

0 commit comments

Comments
 (0)