Skip to content

Commit 52ba2a9

Browse files
authored
use Arc::ptr_eq over .is() (Qiskit#13754)
1 parent d03a61e commit 52ba2a9

File tree

1 file changed

+33
-27
lines changed

1 file changed

+33
-27
lines changed

crates/accelerate/src/sparse_observable.rs

+33-27
Original file line numberDiff line numberDiff line change
@@ -2956,23 +2956,23 @@ impl PySparseObservable {
29562956
other: &Bound<'py, PyAny>,
29572957
) -> PyResult<Bound<'py, PyAny>> {
29582958
let py = slf_.py();
2959-
if slf_.is(other) {
2959+
let Some(other) = coerce_to_observable(other)? else {
2960+
return Ok(py.NotImplemented().into_bound(py));
2961+
};
2962+
2963+
let other = other.borrow();
2964+
let slf_ = slf_.borrow();
2965+
if Arc::ptr_eq(&slf_.inner, &other.inner) {
29602966
// This fast path is for consistency with the in-place `__iadd__`, which would otherwise
29612967
// struggle to do the addition to itself.
2962-
let slf_ = slf_.borrow();
29632968
let inner = slf_.inner.read().map_err(|_| InnerReadError)?;
29642969
return <&SparseObservable as ::std::ops::Mul<_>>::mul(
29652970
&inner,
29662971
Complex64::new(2.0, 0.0),
29672972
)
29682973
.into_bound_py_any(py);
29692974
}
2970-
let Some(other) = coerce_to_observable(other)? else {
2971-
return Ok(py.NotImplemented().into_bound(py));
2972-
};
2973-
let slf_ = slf_.borrow();
29742975
let slf_inner = slf_.inner.read().map_err(|_| InnerReadError)?;
2975-
let other = other.borrow();
29762976
let other_inner = other.inner.read().map_err(|_| InnerReadError)?;
29772977
slf_inner.check_equal_qubits(&other_inner)?;
29782978
<&SparseObservable as ::std::ops::Add>::add(&slf_inner, &other_inner).into_bound_py_any(py)
@@ -2992,13 +2992,7 @@ impl PySparseObservable {
29922992
<&SparseObservable as ::std::ops::Add>::add(&other_inner, &inner).into_bound_py_any(py)
29932993
}
29942994

2995-
fn __iadd__(slf_: Bound<Self>, other: &Bound<PyAny>) -> PyResult<()> {
2996-
if slf_.is(other) {
2997-
let slf_ = slf_.borrow();
2998-
let mut slf_inner = slf_.inner.write().map_err(|_| InnerWriteError)?;
2999-
*slf_inner *= Complex64::new(2.0, 0.0);
3000-
return Ok(());
3001-
}
2995+
fn __iadd__(slf_: Bound<PySparseObservable>, other: &Bound<PyAny>) -> PyResult<()> {
30022996
let Some(other) = coerce_to_observable(other)? else {
30032997
// This is not well behaved - we _should_ return `NotImplemented` to Python space
30042998
// without an exception, but limitations in PyO3 prevent this at the moment. See
@@ -3008,9 +3002,18 @@ impl PySparseObservable {
30083002
other.repr()?
30093003
)));
30103004
};
3005+
3006+
let other = other.borrow();
30113007
let slf_ = slf_.borrow();
30123008
let mut slf_inner = slf_.inner.write().map_err(|_| InnerWriteError)?;
3013-
let other = other.borrow();
3009+
3010+
// Check if slf_ and other point to the same SparseObservable object, in which case
3011+
// we just multiply it by 2
3012+
if Arc::ptr_eq(&slf_.inner, &other.inner) {
3013+
*slf_inner *= Complex64::new(2.0, 0.0);
3014+
return Ok(());
3015+
}
3016+
30143017
let other_inner = other.inner.read().map_err(|_| InnerReadError)?;
30153018
slf_inner.check_equal_qubits(&other_inner)?;
30163019
slf_inner.add_assign(&other_inner);
@@ -3022,16 +3025,17 @@ impl PySparseObservable {
30223025
other: &Bound<'py, PyAny>,
30233026
) -> PyResult<Bound<'py, PyAny>> {
30243027
let py = slf_.py();
3025-
if slf_.is(other) {
3026-
return PySparseObservable::zero(slf_.borrow().num_qubits()?).into_bound_py_any(py);
3027-
}
30283028
let Some(other) = coerce_to_observable(other)? else {
30293029
return Ok(py.NotImplemented().into_bound(py));
30303030
};
30313031

3032+
let other = other.borrow();
30323033
let slf_ = slf_.borrow();
3034+
if Arc::ptr_eq(&slf_.inner, &other.inner) {
3035+
return PySparseObservable::zero(slf_.num_qubits()?).into_bound_py_any(py);
3036+
}
3037+
30333038
let slf_inner = slf_.inner.read().map_err(|_| InnerReadError)?;
3034-
let other = other.borrow();
30353039
let other_inner = other.inner.read().map_err(|_| InnerReadError)?;
30363040
slf_inner.check_equal_qubits(&other_inner)?;
30373041
<&SparseObservable as ::std::ops::Sub>::sub(&slf_inner, &other_inner).into_bound_py_any(py)
@@ -3050,13 +3054,6 @@ impl PySparseObservable {
30503054
}
30513055

30523056
fn __isub__(slf_: Bound<PySparseObservable>, other: &Bound<PyAny>) -> PyResult<()> {
3053-
if slf_.is(other) {
3054-
// This is not strictly the same thing as `a - a` if `a` contains non-finite
3055-
// floating-point values (`inf - inf` is `NaN`, for example); we don't really have a
3056-
// clear view on what floating-point guarantees we're going to make right now.
3057-
slf_.borrow_mut().clear()?;
3058-
return Ok(());
3059-
}
30603057
let Some(other) = coerce_to_observable(other)? else {
30613058
// This is not well behaved - we _should_ return `NotImplemented` to Python space
30623059
// without an exception, but limitations in PyO3 prevent this at the moment. See
@@ -3066,9 +3063,18 @@ impl PySparseObservable {
30663063
other.repr()?
30673064
)));
30683065
};
3066+
let other = other.borrow();
30693067
let slf_ = slf_.borrow();
30703068
let mut slf_inner = slf_.inner.write().map_err(|_| InnerWriteError)?;
3071-
let other = other.borrow();
3069+
3070+
if Arc::ptr_eq(&slf_.inner, &other.inner) {
3071+
// This is not strictly the same thing as `a - a` if `a` contains non-finite
3072+
// floating-point values (`inf - inf` is `NaN`, for example); we don't really have a
3073+
// clear view on what floating-point guarantees we're going to make right now.
3074+
slf_inner.clear();
3075+
return Ok(());
3076+
}
3077+
30723078
let other_inner = other.inner.read().map_err(|_| InnerReadError)?;
30733079
slf_inner.check_equal_qubits(&other_inner)?;
30743080
slf_inner.sub_assign(&other_inner);

0 commit comments

Comments
 (0)