Skip to content

Commit 2d3db9a

Browse files
authored
Allow immutable borrow to access QuantumCircuit.parameters (#12918)
* Allow immutable borrow to access `QuantumCircuit.parameters` `QuantumCircuit.parameters` is logically a read-only operation on `QuantumCircuit`. For efficiency in multiple calls to `assign_parameters`, we actually cache the sort order of the internal `ParameterTable` on access. This is purely a caching effect, and should not leak out to users. The previous implementation took a Rust-space mutable borrow out in order to (potentially) mutate the cache. This caused problems if multiple Python threads attempted to call `assign_parameters` simultaneously; it was possible for one thread to give up the GIL during its initial call to `CircuitData::copy` (so an immutable borrow was still live), allowing another thread to continue on to the getter `CircuitData::get_parameters`, which required a mutable borrow, which failed due to the paused thread in `copy`. This moves the cache into a `RefCell`, allowing the parameter getters to take an immutable borrow as the receiver. We now write the cache out only if we *can* take the mutable borrow out necessary. This can mean that other threads will have to repeat the work of re-sorting the parameters, because their borrows were blocking the saving of the cache, but this will not cause failures. The methods on `ParameterTable` that invalidate the cache all require a mutable borrow on the table itself. This makes it impossible for an immutable borrow to exist simultaneously on the cache, so these methods should always succeed to acquire the cache lock to invalidate it. * Use `RefCell::get_mut` where possible In several cases, the previous code was using the runtime-checked `RefCell::borrow_mut` in locations that can be statically proven to be safe to take the mutable reference. Using the correct function for this makes the logic clearer (as well as technically removing a small amount of runtime overhead). * Use `OnceCell` instead of `RefCell` `OnceCell` has less runtime checking than `RefCell` (only whether it is initialised or not, which is an `Option` check), and better represents the dynamic extensions to the borrow checker that we actually need for the caching in this method. All methods that can invalidate the cache all necessarily take `&mut ParameterTable` already, since they will modify Rust-space data. A `OnceCell` can be deinitialised through a mutable reference, so this is fine. The only reason a `&ParameterTable` method would need to mutate the cache is to create it, which is the allowed set of `OnceCell` operations.
1 parent 592c5f4 commit 2d3db9a

File tree

2 files changed

+78
-53
lines changed

2 files changed

+78
-53
lines changed

crates/circuit/src/circuit_data.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ impl CircuitData {
459459
/// Get a (cached) sorted list of the Python-space `Parameter` instances tracked by this circuit
460460
/// data's parameter table.
461461
#[getter]
462-
pub fn get_parameters<'py>(&mut self, py: Python<'py>) -> Bound<'py, PyList> {
462+
pub fn get_parameters<'py>(&self, py: Python<'py>) -> Bound<'py, PyList> {
463463
self.param_table.py_parameters(py)
464464
}
465465

crates/circuit/src/parameter_table.rs

+77-52
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
// copyright notice, and modified files need to carry a notice indicating
1111
// that they have been altered from the originals.
1212

13+
use std::cell::OnceCell;
14+
1315
use hashbrown::hash_map::Entry;
1416
use hashbrown::{HashMap, HashSet};
1517
use thiserror::Error;
@@ -123,18 +125,17 @@ pub struct ParameterTable {
123125
by_name: HashMap<PyBackedStr, ParameterUuid>,
124126
/// Additional information on any `ParameterVector` instances that have elements in the circuit.
125127
vectors: HashMap<VectorUuid, VectorInfo>,
126-
/// Sort order of the parameters. This is lexicographical for most parameters, except elements
127-
/// of a `ParameterVector` are sorted within the vector by numerical index. We calculate this
128-
/// on demand and cache it; an empty `order` implies it is not currently calculated. We don't
129-
/// use `Option<Vec>` so we can re-use the allocation for partial parameter bindings.
128+
/// Cache of the sort order of the parameters. This is lexicographical for most parameters,
129+
/// except elements of a `ParameterVector` are sorted within the vector by numerical index. We
130+
/// calculate this on demand and cache it.
130131
///
131-
/// Any method that adds or a removes a parameter is responsible for invalidating this cache.
132-
order: Vec<ParameterUuid>,
132+
/// Any method that adds or removes a parameter needs to invalidate this.
133+
order_cache: OnceCell<Vec<ParameterUuid>>,
133134
/// Cache of a Python-space list of the parameter objects, in order. We only generate this
134135
/// specifically when asked.
135136
///
136-
/// Any method that adds or a removes a parameter is responsible for invalidating this cache.
137-
py_parameters: Option<Py<PyList>>,
137+
/// Any method that adds or removes a parameter needs to invalidate this.
138+
py_parameters_cache: OnceCell<Py<PyList>>,
138139
}
139140

140141
impl ParameterTable {
@@ -194,8 +195,6 @@ impl ParameterTable {
194195
None
195196
};
196197
self.by_name.insert(name.clone(), uuid);
197-
self.order.clear();
198-
self.py_parameters = None;
199198
let mut uses = HashSet::new();
200199
if let Some(usage) = usage {
201200
uses.insert_unique_unchecked(usage);
@@ -206,6 +205,7 @@ impl ParameterTable {
206205
element,
207206
object: param_ob.clone().unbind(),
208207
});
208+
self.invalidate_cache();
209209
}
210210
}
211211
Ok(uuid)
@@ -231,43 +231,39 @@ impl ParameterTable {
231231
}
232232

233233
/// Get the (maybe cached) Python list of the sorted `Parameter` objects.
234-
pub fn py_parameters<'py>(&mut self, py: Python<'py>) -> Bound<'py, PyList> {
235-
if let Some(py_parameters) = self.py_parameters.as_ref() {
236-
return py_parameters.clone_ref(py).into_bound(py);
237-
}
238-
self.ensure_sorted();
239-
let out = PyList::new_bound(
240-
py,
241-
self.order
242-
.iter()
243-
.map(|uuid| self.by_uuid[uuid].object.clone_ref(py).into_bound(py)),
244-
);
245-
self.py_parameters = Some(out.clone().unbind());
246-
out
234+
pub fn py_parameters<'py>(&self, py: Python<'py>) -> Bound<'py, PyList> {
235+
self.py_parameters_cache
236+
.get_or_init(|| {
237+
PyList::new_bound(
238+
py,
239+
self.order_cache
240+
.get_or_init(|| self.sorted_order())
241+
.iter()
242+
.map(|uuid| self.by_uuid[uuid].object.bind(py).clone()),
243+
)
244+
.unbind()
245+
})
246+
.bind(py)
247+
.clone()
247248
}
248249

249250
/// Get a Python set of all tracked `Parameter` objects.
250251
pub fn py_parameters_unsorted<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PySet>> {
251252
PySet::new_bound(py, self.by_uuid.values().map(|info| &info.object))
252253
}
253254

254-
/// Ensure that the `order` field is populated and sorted.
255-
fn ensure_sorted(&mut self) {
256-
// If `order` is already populated, it's sorted; it's the responsibility of the methods of
257-
// this struct that mutate it to invalidate the cache.
258-
if !self.order.is_empty() {
259-
return;
260-
}
261-
self.order.reserve(self.by_uuid.len());
262-
self.order.extend(self.by_uuid.keys());
263-
self.order.sort_unstable_by_key(|uuid| {
255+
/// Get the sorted order of the `ParameterTable`. This does not access the cache.
256+
fn sorted_order(&self) -> Vec<ParameterUuid> {
257+
let mut out = self.by_uuid.keys().copied().collect::<Vec<_>>();
258+
out.sort_unstable_by_key(|uuid| {
264259
let info = &self.by_uuid[uuid];
265260
if let Some(vec) = info.element.as_ref() {
266261
(&self.vectors[&vec.vector_uuid].name, vec.index)
267262
} else {
268263
(&info.name, 0)
269264
}
270-
})
265+
});
266+
out
271267
}
272268

273269
/// Add a use of a parameter to the table.
@@ -310,9 +306,8 @@ impl ParameterTable {
310306
vec_entry.remove_entry();
311307
}
312308
}
313-
self.order.clear();
314-
self.py_parameters = None;
315309
entry.remove_entry();
310+
self.invalidate_cache();
316311
}
317312
Ok(())
318313
}
@@ -337,26 +332,28 @@ impl ParameterTable {
337332
(vector_info.refcount > 0).then_some(vector_info)
338333
});
339334
}
340-
self.order.clear();
341-
self.py_parameters = None;
335+
self.invalidate_cache();
342336
Ok(info.uses)
343337
}
344338

345339
/// Clear this table, yielding the Python parameter objects and their uses in sorted order.
340+
///
341+
/// The clearing effect is eager and not dependent on the iteration.
346342
pub fn drain_ordered(
347-
&'_ mut self,
348-
) -> impl Iterator<Item = (Py<PyAny>, HashSet<ParameterUse>)> + '_ {
349-
self.ensure_sorted();
343+
&mut self,
344+
) -> impl ExactSizeIterator<Item = (Py<PyAny>, HashSet<ParameterUse>)> {
345+
let order = self
346+
.order_cache
347+
.take()
348+
.unwrap_or_else(|| self.sorted_order());
349+
let by_uuid = ::std::mem::take(&mut self.by_uuid);
350350
self.by_name.clear();
351351
self.vectors.clear();
352-
self.py_parameters = None;
353-
self.order.drain(..).map(|uuid| {
354-
let info = self
355-
.by_uuid
356-
.remove(&uuid)
357-
.expect("tracked UUIDs should be consistent");
358-
(info.object, info.uses)
359-
})
352+
self.py_parameters_cache.take();
353+
ParameterTableDrain {
354+
order: order.into_iter(),
355+
by_uuid,
356+
}
360357
}
361358

362359
/// Empty this `ParameterTable` of all its contents. This does not affect the capacities of the
@@ -365,8 +362,12 @@ impl ParameterTable {
365362
self.by_uuid.clear();
366363
self.by_name.clear();
367364
self.vectors.clear();
368-
self.order.clear();
369-
self.py_parameters = None;
365+
self.invalidate_cache();
366+
}
367+
368+
fn invalidate_cache(&mut self) {
369+
self.order_cache.take();
370+
self.py_parameters_cache.take();
370371
}
371372

372373
/// Expose the tracked data for a given parameter as directly as possible to Python space.
@@ -401,9 +402,33 @@ impl ParameterTable {
401402
visit.call(&info.object)?
402403
}
403404
// We don't need to / can't visit the `PyBackedStr` stores.
404-
if let Some(list) = self.py_parameters.as_ref() {
405+
if let Some(list) = self.py_parameters_cache.get() {
405406
visit.call(list)?
406407
}
407408
Ok(())
408409
}
409410
}
411+
412+
struct ParameterTableDrain {
413+
order: ::std::vec::IntoIter<ParameterUuid>,
414+
by_uuid: HashMap<ParameterUuid, ParameterInfo>,
415+
}
416+
impl Iterator for ParameterTableDrain {
417+
type Item = (Py<PyAny>, HashSet<ParameterUse>);
418+
419+
fn next(&mut self) -> Option<Self::Item> {
420+
self.order.next().map(|uuid| {
421+
let info = self
422+
.by_uuid
423+
.remove(&uuid)
424+
.expect("tracked UUIDs should be consistent");
425+
(info.object, info.uses)
426+
})
427+
}
428+
429+
fn size_hint(&self) -> (usize, Option<usize>) {
430+
self.order.size_hint()
431+
}
432+
}
433+
impl ExactSizeIterator for ParameterTableDrain {}
434+
impl ::std::iter::FusedIterator for ParameterTableDrain {}

0 commit comments

Comments
 (0)