Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use average gate fidelity in the commutation checker #13874

Merged
merged 22 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions crates/accelerate/src/commutation_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ pub(crate) fn analyze_commutations_inner(
py: Python,
dag: &mut DAGCircuit,
commutation_checker: &mut CommutationChecker,
approximation_degree: f64,
) -> PyResult<(CommutationSet, NodeIndices)> {
let mut commutation_set: CommutationSet = HashMap::new();
let mut node_indices: NodeIndices = HashMap::new();
Expand Down Expand Up @@ -102,6 +103,7 @@ pub(crate) fn analyze_commutations_inner(
qargs2,
cargs2,
MAX_NUM_QUBITS,
approximation_degree,
)?;
if !all_commute {
break;
Expand Down Expand Up @@ -132,17 +134,19 @@ pub(crate) fn analyze_commutations_inner(
}

#[pyfunction]
#[pyo3(signature = (dag, commutation_checker))]
#[pyo3(signature = (dag, commutation_checker, approximation_degree=1.))]
pub(crate) fn analyze_commutations(
py: Python,
dag: &mut DAGCircuit,
commutation_checker: &mut CommutationChecker,
approximation_degree: f64,
) -> PyResult<Py<PyDict>> {
// This returns two HashMaps:
// * The commuting nodes per wire: {wire: [commuting_nodes_1, commuting_nodes_2, ...]}
// * The index in which commutation set a given node is located on a wire: {(node, wire): index}
// The Python dict will store both of these dictionaries in one.
let (commutation_set, node_indices) = analyze_commutations_inner(py, dag, commutation_checker)?;
let (commutation_set, node_indices) =
analyze_commutations_inner(py, dag, commutation_checker, approximation_degree)?;

let out_dict = PyDict::new(py);

Expand Down
6 changes: 4 additions & 2 deletions crates/accelerate/src/commutation_cancellation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,13 @@ struct CancellationSetKey {
}

#[pyfunction]
#[pyo3(signature = (dag, commutation_checker, basis_gates=None))]
#[pyo3(signature = (dag, commutation_checker, basis_gates=None, approximation_degree=1.))]
pub(crate) fn cancel_commutations(
py: Python,
dag: &mut DAGCircuit,
commutation_checker: &mut CommutationChecker,
basis_gates: Option<HashSet<String>>,
approximation_degree: f64,
) -> PyResult<()> {
let basis: HashSet<String> = if let Some(basis) = basis_gates {
basis
Expand Down Expand Up @@ -97,7 +98,8 @@ pub(crate) fn cancel_commutations(
sec_commutation_set_id), the value is the list gates that share the same gate type,
qubits and commutation sets.
*/
let (commutation_set, node_indices) = analyze_commutations_inner(py, dag, commutation_checker)?;
let (commutation_set, node_indices) =
analyze_commutations_inner(py, dag, commutation_checker, approximation_degree)?;
let mut cancellation_sets: HashMap<CancellationSetKey, Vec<NodeIndex>> = HashMap::new();

(0..dag.num_qubits() as u32).for_each(|qubit| {
Expand Down
205 changes: 82 additions & 123 deletions crates/accelerate/src/commutation_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use hashbrown::{HashMap, HashSet};
use ndarray::linalg::kron;
use ndarray::Array2;
use num_complex::Complex64;
use num_complex::ComplexFloat;
use once_cell::sync::Lazy;
use smallvec::SmallVec;

Expand All @@ -34,6 +35,7 @@ use qiskit_circuit::operations::{
};
use qiskit_circuit::{BitType, Clbit, Qubit};

use crate::gate_metrics;
use crate::unitary_compose;
use crate::QiskitError;

Expand All @@ -54,48 +56,30 @@ static SUPPORTED_OP: Lazy<HashSet<&str>> = Lazy::new(|| {
// and their pi-periodicity. Here we mean a gate is n-pi periodic, if for angles that are
// multiples of n*pi, the gate is equal to the identity up to a global phase.
// E.g. RX is generated by X and 2-pi periodic, while CRX is generated by CX and 4-pi periodic.
static SUPPORTED_ROTATIONS: Lazy<HashMap<&str, (u8, Option<OperationRef>)>> = Lazy::new(|| {
static SUPPORTED_ROTATIONS: Lazy<HashMap<&str, Option<OperationRef>>> = Lazy::new(|| {
HashMap::from([
(
"rx",
(2, Some(OperationRef::StandardGate(StandardGate::XGate))),
),
(
"ry",
(2, Some(OperationRef::StandardGate(StandardGate::YGate))),
),
(
"rz",
(2, Some(OperationRef::StandardGate(StandardGate::ZGate))),
),
(
"p",
(2, Some(OperationRef::StandardGate(StandardGate::ZGate))),
),
(
"u1",
(2, Some(OperationRef::StandardGate(StandardGate::ZGate))),
),
("rxx", (2, None)), // None means the gate is in the commutation dictionary
("ryy", (2, None)),
("rzx", (2, None)),
("rzz", (2, None)),
("rx", Some(OperationRef::StandardGate(StandardGate::XGate))),
("ry", Some(OperationRef::StandardGate(StandardGate::YGate))),
("rz", Some(OperationRef::StandardGate(StandardGate::ZGate))),
("p", Some(OperationRef::StandardGate(StandardGate::ZGate))),
("u1", Some(OperationRef::StandardGate(StandardGate::ZGate))),
("rxx", None), // None means the gate is in the commutation dictionary
("ryy", None),
("rzx", None),
("rzz", None),
(
"crx",
(4, Some(OperationRef::StandardGate(StandardGate::CXGate))),
Some(OperationRef::StandardGate(StandardGate::CXGate)),
),
(
"cry",
(4, Some(OperationRef::StandardGate(StandardGate::CYGate))),
Some(OperationRef::StandardGate(StandardGate::CYGate)),
),
(
"crz",
(4, Some(OperationRef::StandardGate(StandardGate::CZGate))),
),
(
"cp",
(2, Some(OperationRef::StandardGate(StandardGate::CZGate))),
Some(OperationRef::StandardGate(StandardGate::CZGate)),
),
("cp", Some(OperationRef::StandardGate(StandardGate::CZGate))),
])
});

Expand Down Expand Up @@ -155,13 +139,14 @@ impl CommutationChecker {
}
}

#[pyo3(signature=(op1, op2, max_num_qubits=3))]
#[pyo3(signature=(op1, op2, max_num_qubits=3, approximation_degree=1.))]
fn commute_nodes(
&mut self,
py: Python,
op1: &DAGOpNode,
op2: &DAGOpNode,
max_num_qubits: u32,
approximation_degree: f64,
) -> PyResult<bool> {
let (qargs1, qargs2) = get_bits::<Qubit>(
py,
Expand All @@ -185,10 +170,11 @@ impl CommutationChecker {
&qargs2,
&cargs2,
max_num_qubits,
approximation_degree,
)
}

#[pyo3(signature=(op1, qargs1, cargs1, op2, qargs2, cargs2, max_num_qubits=3))]
#[pyo3(signature=(op1, qargs1, cargs1, op2, qargs2, cargs2, max_num_qubits=3, approximation_degree=1.))]
#[allow(clippy::too_many_arguments)]
fn commute(
&mut self,
Expand All @@ -200,6 +186,7 @@ impl CommutationChecker {
qargs2: Option<&Bound<PySequence>>,
cargs2: Option<&Bound<PySequence>>,
max_num_qubits: u32,
approximation_degree: f64,
) -> PyResult<bool> {
let qargs1 = qargs1.map_or_else(|| Ok(PyTuple::empty(py)), PySequenceMethods::to_tuple)?;
let cargs1 = cargs1.map_or_else(|| Ok(PyTuple::empty(py)), PySequenceMethods::to_tuple)?;
Expand All @@ -220,6 +207,7 @@ impl CommutationChecker {
&qargs2,
&cargs2,
max_num_qubits,
approximation_degree,
)
}

Expand Down Expand Up @@ -288,20 +276,20 @@ impl CommutationChecker {
qargs2: &[Qubit],
cargs2: &[Clbit],
max_num_qubits: u32,
approximation_degree: f64,
) -> PyResult<bool> {
// relative and absolute tolerance used to (1) check whether rotation gates commute
// trivially (i.e. the rotation angle is so small we assume it commutes) and (2) define
// comparison for the matrix-based commutation checks
let rtol = 1e-5;
let atol = 1e-8;
// If the average gate infidelity is below this tolerance, they commute. The tolerance
// is set to max(1e-12, 1 - approximation_degree), to account for roundoffs and for
// consistency with other places in Qiskit.
let tol = 1e-12_f64.max(1. - approximation_degree);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does mean we can't approximate more than what a 1e-12 tolerance provides. But it's the same logic we use in other places so I think that's fine.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it is unfortunately. I don't like this very much either, but currently approximation_degree is 1. everywhere per default which doesn't really allow us to set another default like 1-1e-12. The other option would've been to use None as indicator, but that already means that the target error rates should be used.


// if we have rotation gates, we attempt to map them to their generators, for example
// RX -> X or CPhase -> CZ
let (op1, params1, trivial1) = map_rotation(op1, params1, rtol);
let (op1, params1, trivial1) = map_rotation(op1, params1, tol);
if trivial1 {
return Ok(true);
}
let (op2, params2, trivial2) = map_rotation(op2, params2, rtol);
let (op2, params2, trivial2) = map_rotation(op2, params2, tol);
if trivial2 {
return Ok(true);
}
Expand Down Expand Up @@ -367,8 +355,7 @@ impl CommutationChecker {
second_op,
second_params,
second_qargs,
rtol,
atol,
tol,
);
}

Expand Down Expand Up @@ -403,8 +390,7 @@ impl CommutationChecker {
second_op,
second_params,
second_qargs,
rtol,
atol,
tol,
)?;

// TODO: implement a LRU cache for this
Expand Down Expand Up @@ -439,8 +425,7 @@ impl CommutationChecker {
second_op: &OperationRef,
second_params: &[Param],
second_qargs: &[Qubit],
rtol: f64,
atol: f64,
tol: f64,
) -> PyResult<bool> {
// Compute relative positioning of qargs of the second gate to the first gate.
// Since the qargs come out the same BitData, we already know there are no accidential
Expand Down Expand Up @@ -481,81 +466,49 @@ impl CommutationChecker {
None => return Ok(false),
};

if first_qarg == second_qarg {
match first_qarg.len() {
1 => Ok(unitary_compose::commute_1q(
&first_mat.view(),
&second_mat.view(),
rtol,
atol,
)),
2 => Ok(unitary_compose::commute_2q(
&first_mat.view(),
&second_mat.view(),
&[Qubit(0), Qubit(1)],
rtol,
atol,
)),
_ => Ok(unitary_compose::allclose(
&second_mat.dot(&first_mat).view(),
&first_mat.dot(&second_mat).view(),
rtol,
atol,
)),
}
// TODO Optimize this bit to avoid unnecessary Kronecker products:
// 1. We currently sort the operations for the cache by operation size, putting the
// *smaller* operation first: (smaller op, larger op)
// 2. This code here expands the first op to match the second -- hence we always
// match the operator sizes.
// This whole extension logic could be avoided since we know the second one is larger.
let extra_qarg2 = num_qubits - first_qarg.len() as u32;
let first_mat = if extra_qarg2 > 0 {
let id_op = Array2::<Complex64>::eye(usize::pow(2, extra_qarg2));
kron(&id_op, &first_mat)
} else {
// TODO Optimize this bit to avoid unnecessary Kronecker products:
// 1. We currently sort the operations for the cache by operation size, putting the
// *smaller* operation first: (smaller op, larger op)
// 2. This code here expands the first op to match the second -- hence we always
// match the operator sizes.
// This whole extension logic could be avoided since we know the second one is larger.
let extra_qarg2 = num_qubits - first_qarg.len() as u32;
let first_mat = if extra_qarg2 > 0 {
let id_op = Array2::<Complex64>::eye(usize::pow(2, extra_qarg2));
kron(&id_op, &first_mat)
} else {
first_mat
};

// the 1 qubit case cannot happen, since that would already have been captured
// by the previous if clause; first_qarg == second_qarg (if they overlap they must
// be the same)
if num_qubits == 2 {
return Ok(unitary_compose::commute_2q(
&first_mat.view(),
&second_mat.view(),
&second_qarg,
rtol,
atol,
));
};
first_mat
};

let op12 = match unitary_compose::compose(
&first_mat.view(),
&second_mat.view(),
&second_qarg,
false,
) {
Ok(matrix) => matrix,
Err(e) => return Err(PyRuntimeError::new_err(e)),
};
let op21 = match unitary_compose::compose(
&first_mat.view(),
&second_mat.view(),
&second_qarg,
true,
) {
Ok(matrix) => matrix,
Err(e) => return Err(PyRuntimeError::new_err(e)),
};
Ok(unitary_compose::allclose(
&op12.view(),
&op21.view(),
rtol,
atol,
))
}
// the 1 qubit case cannot happen, since that would already have been captured
// by the previous if clause; first_qarg == second_qarg (if they overlap they must
// be the same)
let op12 = match unitary_compose::compose(
&first_mat.view(),
&second_mat.view(),
&second_qarg,
false,
) {
Ok(matrix) => matrix,
Err(e) => return Err(PyRuntimeError::new_err(e)),
};
let op21 = match unitary_compose::compose(
&first_mat.view(),
&second_mat.view(),
&second_qarg,
true,
) {
Ok(matrix) => matrix,
Err(e) => return Err(PyRuntimeError::new_err(e)),
};
let (fid, phase) = gate_metrics::gate_fidelity(&op12.view(), &op21.view(), None);

// we consider the gates as commuting if the process fidelity of
// AB (BA)^\dagger is approximately the identity and there is no global phase difference
// let dim = op12.ncols() as f64;
// let matrix_tol = tol * dim.powi(2);
let matrix_tol = tol;
Ok(phase.abs() <= tol && (1.0 - fid).abs() <= matrix_tol)
}

fn clear_cache(&mut self) {
Expand Down Expand Up @@ -652,13 +605,19 @@ fn map_rotation<'a>(
) -> (&'a OperationRef<'a>, &'a [Param], bool) {
let name = op.name();

if let Some((pi_multiple, generator)) = SUPPORTED_ROTATIONS.get(name) {
if let Some(generator) = SUPPORTED_ROTATIONS.get(name) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe not for this PR but we can drop the hashset here all together and have a static lookup table based on the standard gates. Something like:

const fn builld_lut() -> [Option<StandardGate>; STANDARD_GATE_SIZE]  {
    ...
}
let static supported_rotations: [Option<StandardGate>; STANDARD_GATE_SIZE] = build_lut()

Alternatively you could just do a couple of if matches!() on .standard_gate().

When in rust space and working solely with standard gates we really never have a need for strings, so this just sticks out to me every time I see it. But it was pre-existing so we can do this in a follow up.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that sounds like a good idea, I'll put it on the follow-up list if that's good for you 🙂

// If the rotation angle is below the tolerance, the gate is assumed to
// commute with everything, and we simply return the operation with the flag that
// it commutes trivially.
if let Param::Float(angle) = params[0] {
let periodicity = (*pi_multiple as f64) * ::std::f64::consts::PI;
if (angle % periodicity).abs() < tol {
let gate = op
.standard_gate()
.expect("Supported gates are standard gates");
let (tr_over_dim, dim) = gate_metrics::rotation_trace_and_dim(gate, angle)
.expect("All rotation should be covered at this point");
let gate_fidelity = tr_over_dim.abs().powi(2);
let process_fidelity = (dim * gate_fidelity + 1.) / (dim + 1.);
if (1. - process_fidelity).abs() <= tol {
return (op, params, true);
};
};
Expand Down
Loading