Skip to content

Commit 37c0780

Browse files
committed
Fix Python->Rust Param conversion
This commit adds a custom implementation of the FromPyObject trait for the Param enum. Previously, the Param trait derived it's impl of the trait, but this logic wasn't perfect. In cases whern a ParameterExpression was effectively a constant (such as `0 * x`) the trait's attempt to coerce to a float first would result in those ParameterExpressions being dropped from the circuit at insertion time. This was a change in behavior from before having gates in Rust as the parameters would disappear from the circuit at insertion time instead of at bind time. This commit fixes this by having a custom impl for FromPyObject that first tries to figure out if the parameter is a ParameterExpression (or a QuantumCircuit) by using a Python isinstance() check, then tries to extract it as a float, and finally stores a non-parameter object; which is a new variant in the Param enum. This new variant also lets us simplify the logic around adding gates to the parameter table as we're able to know ahead of time which gate parameters are `ParameterExpression`s and which are other objects (and don't need to be tracked in the parameter table. Additionally this commit tweaks two tests, the first is test.python.circuit.library.test_nlocal.TestNLocal.test_parameters_setter which was adjusted in the previous commit to workaround the bug fixed by this commit. The second is test.python.circuit.test_parameters which was testing that a bound ParameterExpression with a value of 0 defaults to an int which was a side effect of passing an int input to symengine for the bind value and not part of the api and didn't need to be checked. This assertion was removed from the test because the rust representation is only storing f64 values for the numeric parameters and it is never an int after binding from the Python perspective it isn't any different to have float(0) and int(0) unless you explicit isinstance check like the test previously was.
1 parent ad3e3c5 commit 37c0780

File tree

5 files changed

+46
-44
lines changed

5 files changed

+46
-44
lines changed

crates/circuit/src/circuit_data.rs

+3-40
Original file line numberDiff line numberDiff line change
@@ -280,33 +280,11 @@ impl CircuitData {
280280
let mut new_param = false;
281281
let inst_params = &self.data[inst_index].params;
282282
if let Some(raw_params) = inst_params {
283-
let param_mod =
284-
PyModule::import_bound(py, intern!(py, "qiskit.circuit.parameterexpression"))?;
285-
let param_class = param_mod.getattr(intern!(py, "ParameterExpression"))?;
286-
let circuit_mod =
287-
PyModule::import_bound(py, intern!(py, "qiskit.circuit.quantumcircuit"))?;
288-
let circuit_class = circuit_mod.getattr(intern!(py, "QuantumCircuit"))?;
289283
let params: Vec<(usize, PyObject)> = raw_params
290284
.iter()
291285
.enumerate()
292286
.filter_map(|(idx, x)| match x {
293-
Param::ParameterExpression(param_obj) => {
294-
if param_obj
295-
.clone_ref(py)
296-
.into_bound(py)
297-
.is_instance(&param_class)
298-
.unwrap()
299-
|| param_obj
300-
.clone_ref(py)
301-
.into_bound(py)
302-
.is_instance(&circuit_class)
303-
.unwrap()
304-
{
305-
Some((idx, param_obj.clone_ref(py)))
306-
} else {
307-
None
308-
}
309-
}
287+
Param::ParameterExpression(param_obj) => Some((idx, param_obj.clone_ref(py))),
310288
_ => None,
311289
})
312290
.collect();
@@ -370,23 +348,7 @@ impl CircuitData {
370348
.iter()
371349
.enumerate()
372350
.filter_map(|(idx, x)| match x {
373-
Param::ParameterExpression(param_obj) => {
374-
let param_mod =
375-
PyModule::import_bound(py, "qiskit.circuit.parameterexpression")
376-
.ok()?;
377-
let param_class =
378-
param_mod.getattr(intern!(py, "ParameterExpression")).ok()?;
379-
if param_obj
380-
.clone_ref(py)
381-
.into_bound(py)
382-
.is_instance(&param_class)
383-
.unwrap()
384-
{
385-
Some((idx, param_obj.clone_ref(py)))
386-
} else {
387-
None
388-
}
389-
}
351+
Param::ParameterExpression(param_obj) => Some((idx, param_obj.clone_ref(py))),
390352
_ => None,
391353
})
392354
.collect();
@@ -1131,6 +1093,7 @@ impl CircuitData {
11311093
}
11321094
self.global_phase = Param::ParameterExpression(angle);
11331095
}
1096+
Param::Obj(_) => return Err(PyValueError::new_err("Invalid type for global phase")),
11341097
};
11351098
Ok(())
11361099
}

crates/circuit/src/circuit_instruction.rs

+11
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,17 @@ impl CircuitInstruction {
452452
break;
453453
}
454454
}
455+
Param::Obj(val_a) => {
456+
if let Param::Obj(val_b) = param_b {
457+
if !val_a.bind(py).eq(val_b.bind(py))? {
458+
out = false;
459+
break;
460+
}
461+
} else {
462+
out = false;
463+
break;
464+
}
465+
}
455466
}
456467
}
457468
out

crates/circuit/src/operations.rs

+31-2
Original file line numberDiff line numberDiff line change
@@ -116,17 +116,41 @@ pub trait Operation {
116116
fn standard_gate(&self) -> Option<StandardGate>;
117117
}
118118

119-
#[derive(FromPyObject, Clone, Debug)]
119+
#[derive(Clone, Debug)]
120120
pub enum Param {
121-
Float(f64),
122121
ParameterExpression(PyObject),
122+
Float(f64),
123+
Obj(PyObject),
124+
}
125+
126+
impl<'py> FromPyObject<'py> for Param {
127+
fn extract_bound(b: &Bound<'py, PyAny>) -> Result<Self, PyErr> {
128+
let param_mod = PyModule::import_bound(
129+
b.py(),
130+
intern!(b.py(), "qiskit.circuit.parameterexpression"),
131+
)?;
132+
let param_class = param_mod.getattr(intern!(b.py(), "ParameterExpression"))?;
133+
let circuit_mod =
134+
PyModule::import_bound(b.py(), intern!(b.py(), "qiskit.circuit.quantumcircuit"))?;
135+
let circuit_class = circuit_mod.getattr(intern!(b.py(), "QuantumCircuit"))?;
136+
Ok(
137+
if b.is_instance(&param_class)? || b.is_instance(&circuit_class)? {
138+
Param::ParameterExpression(b.clone().unbind())
139+
} else if let Ok(val) = b.extract::<f64>() {
140+
Param::Float(val)
141+
} else {
142+
Param::Obj(b.clone().unbind())
143+
},
144+
)
145+
}
123146
}
124147

125148
impl IntoPy<PyObject> for Param {
126149
fn into_py(self, py: Python) -> PyObject {
127150
match &self {
128151
Self::Float(val) => val.to_object(py),
129152
Self::ParameterExpression(val) => val.clone_ref(py),
153+
Self::Obj(val) => val.clone_ref(py),
130154
}
131155
}
132156
}
@@ -136,6 +160,7 @@ impl ToPyObject for Param {
136160
match self {
137161
Self::Float(val) => val.to_object(py),
138162
Self::ParameterExpression(val) => val.clone_ref(py),
163+
Self::Obj(val) => val.clone_ref(py),
139164
}
140165
}
141166
}
@@ -328,14 +353,17 @@ impl Operation for StandardGate {
328353
let theta: Option<f64> = match params[0] {
329354
Param::Float(val) => Some(val),
330355
Param::ParameterExpression(_) => None,
356+
Param::Obj(_) => None,
331357
};
332358
let phi: Option<f64> = match params[1] {
333359
Param::Float(val) => Some(val),
334360
Param::ParameterExpression(_) => None,
361+
Param::Obj(_) => None,
335362
};
336363
let lam: Option<f64> = match params[2] {
337364
Param::Float(val) => Some(val),
338365
Param::ParameterExpression(_) => None,
366+
Param::Obj(_) => None,
339367
};
340368
// If let chains as needed here are unstable ignore clippy to
341369
// workaround. Upstream rust tracking issue:
@@ -471,6 +499,7 @@ impl Operation for StandardGate {
471499
)
472500
.expect("Unexpected Qiskit python bug"),
473501
),
502+
Param::Obj(_) => unreachable!(),
474503
}
475504
}),
476505
Self::ECRGate => todo!("Add when we have RZX"),

test/python/circuit/library/test_nlocal.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def test_parameters_setter(self, params):
235235
initial_params = ParameterVector("p", length=6)
236236
circuit = QuantumCircuit(1)
237237
for i, initial_param in enumerate(initial_params):
238-
circuit.ry((i + 1) * initial_param, 0)
238+
circuit.ry(i * initial_param, 0)
239239

240240
# create an NLocal from the circuit and set the new parameters
241241
nlocal = NLocal(1, entanglement_blocks=circuit, reps=1)

test/python/circuit/test_parameters.py

-1
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,6 @@ def test_expression_partial_binding_zero(self):
585585
fbqc = pqc.assign_parameters({phi: 1})
586586

587587
self.assertEqual(fbqc.parameters, set())
588-
self.assertIsInstance(fbqc.data[0].operation.params[0], int)
589588
self.assertEqual(float(fbqc.data[0].operation.params[0]), 0)
590589

591590
def test_raise_if_assigning_params_not_in_circuit(self):

0 commit comments

Comments
 (0)