Skip to content

Commit 59a1e05

Browse files
mudit2812vincentmralbi3rogithub-actions[bot]astralcai
authored
Improve support for Torch and Jax with dynamic_one_shot (#5672)
**Context:** Opened in favour of #5630. Bug fix for #5442. This PR updates `dynamic_one_shot` so that it has better compatibility with the `torch` and `jax` interfaces. **Description of the Change:** * Change casting method from `array.astype()` to `qml.math.cast` in the `apply_operation` dispatch for `MidMeasureMP`. * Update usage of `qml.math` in `dynamic_one_shot`. * When using `qml.counts`, cast results to ints before converting to strings for lists of MCM values and floats for single MCM values. This is needed because jax arrays are not hashable, and the hash of torch tensors seems to be independent of the value(s) stored inside it. Thus, neither can be used as keys for dictionaries. **Benefits:** Better interface support with `dynamic_one_shot`. **Possible Drawbacks:** **Related GitHub Issues:** --------- Co-authored-by: Vincent Michaud-Rioux <vincent.michaud-rioux@xanadu.ai> Co-authored-by: Vincent Michaud-Rioux <vincentm@nanoacademic.com> Co-authored-by: Christina Lee <christina@xanadu.ai> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Astral Cai <astral.cai@xanadu.ai> Co-authored-by: David Wierichs <david.wierichs@xanadu.ai> Co-authored-by: lillian542 <38584660+lillian542@users.noreply.github.com> Co-authored-by: Pietropaolo Frisoni <pietropaolo.frisoni@xanadu.ai> Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com> Co-authored-by: Jay Soni <jbsoni@uwaterloo.ca> Co-authored-by: Guillermo Alonso-Linaje <65235481+KetpuntoG@users.noreply.github.com> Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com> Co-authored-by: Isaac De Vlugt <34751083+isaacdevlugt@users.noreply.github.com> Co-authored-by: Diksha Dhawan <40900030+ddhawan11@users.noreply.github.com> Co-authored-by: Isaac De Vlugt <isaacdevlugt@gmail.com> Co-authored-by: Diego <67476785+DSGuala@users.noreply.github.com> Co-authored-by: trbromley <brotho02@gmail.com> Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com> Co-authored-by: David Ittah <dime10@users.noreply.github.com> Co-authored-by: soranjh <40344468+soranjh@users.noreply.github.com>
1 parent fbc2a39 commit 59a1e05

File tree

7 files changed

+244
-17
lines changed

7 files changed

+244
-17
lines changed

doc/releases/changelog-dev.md

+4
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,9 @@
156156

157157
<h3>Bug fixes 🐛</h3>
158158

159+
* The `dynamic_one_shot` transform now has expanded support for the `jax` and `torch` interfaces.
160+
[(#5672)](https://github.com/PennyLaneAI/pennylane/pull/5672)
161+
159162
* The decomposition of `StronglyEntanglingLayers` is now compatible with broadcasting.
160163
[(#5716)](https://github.com/PennyLaneAI/pennylane/pull/5716)
161164

@@ -213,5 +216,6 @@ Korbinian Kottmann,
213216
Christina Lee,
214217
Vincent Michaud-Rioux,
215218
Lee James O'Riordan,
219+
Mudit Pandey,
216220
Kenya Sakka,
217221
David Wierichs.

pennylane/devices/qubit/apply_operation.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -330,8 +330,10 @@ def binomial_fn(n, p):
330330
# to reset enables jax.jit and prevents it from using Python callbacks
331331
element = op.reset and sample == 1
332332
matrix = qml.math.array(
333-
[[(element + 1) % 2, (element) % 2], [(element) % 2, (element + 1) % 2]], like=interface
334-
).astype(float)
333+
[[(element + 1) % 2, (element) % 2], [(element) % 2, (element + 1) % 2]],
334+
like=interface,
335+
dtype=float,
336+
)
335337
state = apply_operation(
336338
qml.QubitUnitary(matrix, wire), state, is_state_batched=is_state_batched, debugger=debugger
337339
)

pennylane/devices/qubit/simulate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def simulate(
287287
trainable_params=circuit.trainable_params,
288288
)
289289
keys = jax_random_split(prng_key, num=circuit.shots.total_shots)
290-
if qml.math.get_deep_interface(circuit.data) == "jax":
290+
if qml.math.get_deep_interface(circuit.data) == "jax" and prng_key is not None:
291291
# pylint: disable=import-outside-toplevel
292292
import jax
293293

pennylane/math/single_dispatch.py

+1
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ def _take_autograd(tensor, indices, axis=None):
242242
ar.autoray._SUBMODULE_ALIASES["tensorflow", "isclose"] = "tensorflow.experimental.numpy"
243243
ar.autoray._SUBMODULE_ALIASES["tensorflow", "atleast_1d"] = "tensorflow.experimental.numpy"
244244
ar.autoray._SUBMODULE_ALIASES["tensorflow", "all"] = "tensorflow.experimental.numpy"
245+
ar.autoray._SUBMODULE_ALIASES["tensorflow", "ravel"] = "tensorflow.experimental.numpy"
245246
ar.autoray._SUBMODULE_ALIASES["tensorflow", "vstack"] = "tensorflow.experimental.numpy"
246247

247248
tf_fft_functions = [

pennylane/transforms/dynamic_one_shot.py

+19-9
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def measurement_with_no_shots(measurement):
228228
)
229229

230230
interface = qml.math.get_deep_interface(circuit.data)
231+
interface = "numpy" if interface == "builtins" else interface
231232

232233
all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)]
233234
n_mcms = len(all_mcms)
@@ -243,10 +244,13 @@ def measurement_with_no_shots(measurement):
243244
mcm_samples = qml.math.array(
244245
[[res] if single_measurement else res[-n_mcms::] for res in results], like=interface
245246
)
246-
has_postselect = qml.math.array([op.postselect is not None for op in all_mcms]).reshape((1, -1))
247+
# Can't use boolean dtype array with tf, hence why conditionally setting items to 0 or 1
248+
has_postselect = qml.math.array(
249+
[[int(op.postselect is not None) for op in all_mcms]], like=interface
250+
)
247251
postselect = qml.math.array(
248-
[0 if op.postselect is None else op.postselect for op in all_mcms]
249-
).reshape((1, -1))
252+
[[0 if op.postselect is None else op.postselect for op in all_mcms]], like=interface
253+
)
250254
is_valid = qml.math.all(mcm_samples * has_postselect == postselect, axis=1)
251255
has_valid = qml.math.any(is_valid)
252256
mid_meas = [op for op in circuit.operations if is_mcm(op)]
@@ -268,7 +272,12 @@ def measurement_with_no_shots(measurement):
268272
meas = measurement_with_no_shots(m)
269273
m_count += 1
270274
else:
271-
result = qml.math.array([res[m_count] for res in results], like=interface)
275+
result = [res[m_count] for res in results]
276+
if not isinstance(m, CountsMP):
277+
# We don't need to cast to arrays when using qml.counts. qml.math.array is not viable
278+
# as it assumes all elements of the input are of builtin python types and not belonging
279+
# to any particular interface
280+
result = qml.math.stack(result, like=interface)
272281
meas = gather_non_mcm(m, result, is_valid)
273282
m_count += 1
274283
if isinstance(m, SampleMP):
@@ -292,7 +301,9 @@ def gather_non_mcm(circuit_measurement, measurement, is_valid):
292301
if isinstance(circuit_measurement, CountsMP):
293302
tmp = Counter()
294303
for i, d in enumerate(measurement):
295-
tmp.update(dict((k, v * is_valid[i]) for k, v in d.items()))
304+
tmp.update(
305+
dict((k if isinstance(k, str) else float(k), v * is_valid[i]) for k, v in d.items())
306+
)
296307
tmp = Counter({k: v for k, v in tmp.items() if v > 0})
297308
return dict(sorted(tmp.items()))
298309
if isinstance(circuit_measurement, ExpectationMP):
@@ -341,14 +352,13 @@ def gather_mcm(measurement, samples, is_valid):
341352
counts = qml.math.array(counts, like=interface)
342353
return counts / qml.math.sum(counts)
343354
if isinstance(measurement, CountsMP):
344-
mcm_samples = [{"".join(str(v) for v in tuple(s)): 1} for s in mcm_samples]
355+
mcm_samples = [{"".join(str(int(v)) for v in tuple(s)): 1} for s in mcm_samples]
345356
return gather_non_mcm(measurement, mcm_samples, is_valid)
357+
mcm_samples = qml.math.ravel(qml.math.array(mv.concretize(samples), like=interface))
346358
if isinstance(measurement, ProbabilityMP):
347-
mcm_samples = qml.math.array(mv.concretize(samples), like=interface).ravel()
348359
counts = [qml.math.sum((mcm_samples == v) * is_valid) for v in list(mv.branches.values())]
349360
counts = qml.math.array(counts, like=interface)
350361
return counts / qml.math.sum(counts)
351-
mcm_samples = qml.math.array([mv.concretize(samples)], like=interface).ravel()
352362
if isinstance(measurement, CountsMP):
353-
mcm_samples = [{s: 1} for s in mcm_samples]
363+
mcm_samples = [{float(s): 1} for s in mcm_samples]
354364
return gather_non_mcm(measurement, mcm_samples, is_valid)

tests/devices/default_qubit/test_default_qubit_native_mcm.py

+50-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Tests for default qubit preprocessing."""
15-
from functools import partial, reduce
15+
from functools import reduce
1616
from typing import Iterable, Sequence
1717

1818
import numpy as np
@@ -24,7 +24,11 @@
2424

2525
pytestmark = pytest.mark.slow
2626

27-
get_device = partial(qml.device, name="default.qubit", seed=8237945)
27+
28+
def get_device(**kwargs):
29+
kwargs.setdefault("shots", None)
30+
kwargs.setdefault("seed", 8237945)
31+
return qml.device("default.qubit", **kwargs)
2832

2933

3034
def validate_counts(shots, results1, results2, batch_size=None):
@@ -88,7 +92,7 @@ def validate_samples(shots, results1, results2, batch_size=None):
8892
assert results1.ndim == results2.ndim
8993
if results2.ndim > 1:
9094
assert results1.shape[1] == results2.shape[1]
91-
np.allclose(np.sum(results1), np.sum(results2), atol=20, rtol=0.2)
95+
np.allclose(qml.math.sum(results1), qml.math.sum(results2), atol=20, rtol=0.2)
9296

9397

9498
def validate_expval(shots, results1, results2, batch_size=None):
@@ -611,7 +615,7 @@ def test_sample_with_prng_key(shots, postselect, reset):
611615
# pylint: disable=import-outside-toplevel
612616
from jax.random import PRNGKey
613617

614-
dev = qml.device("default.qubit", shots=shots, seed=PRNGKey(678))
618+
dev = get_device(shots=shots, seed=PRNGKey(678))
615619
param = [np.pi / 4, np.pi / 3]
616620
obs = qml.PauliZ(0) @ qml.PauliZ(1)
617621

@@ -659,7 +663,7 @@ def test_jax_jit(diff_method, postselect, reset):
659663

660664
shots = 10
661665

662-
dev = qml.device("default.qubit", shots=shots, seed=jax.random.PRNGKey(678))
666+
dev = get_device(shots=shots, seed=jax.random.PRNGKey(678))
663667
params = [np.pi / 2.5, np.pi / 3, -np.pi / 3.5]
664668
obs = qml.PauliY(0)
665669

@@ -750,3 +754,44 @@ def func(x):
750754
results2 = func2(param)
751755
for r1, r2 in zip(results1.keys(), results2.keys()):
752756
assert r1 == r2
757+
758+
759+
@pytest.mark.torch
760+
@pytest.mark.parametrize("postselect", [None, 1])
761+
@pytest.mark.parametrize("diff_method", [None, "best"])
762+
@pytest.mark.parametrize("measure_f", [qml.probs, qml.sample, qml.expval, qml.var])
763+
@pytest.mark.parametrize("meas_obj", [qml.PauliZ(1), [0, 1], "composite_mcm", "mcm_list"])
764+
def test_torch_integration(postselect, diff_method, measure_f, meas_obj):
765+
"""Test that native MCM circuits are executed correctly with Torch"""
766+
if measure_f in (qml.var, qml.expval) and (
767+
isinstance(meas_obj, list) or meas_obj == "mcm_list"
768+
):
769+
pytest.skip("Can't use wires/mcm lists with var or expval")
770+
771+
import torch
772+
773+
shots = 7000
774+
dev = get_device(shots=shots, seed=123456789)
775+
param = torch.tensor(np.pi / 3, dtype=torch.float64)
776+
777+
@qml.qnode(dev, diff_method=diff_method)
778+
def func(x):
779+
qml.RX(x, 0)
780+
m0 = qml.measure(0)
781+
qml.RX(0.5 * x, 1)
782+
m1 = qml.measure(1, postselect=postselect)
783+
qml.cond((m0 + m1) == 2, qml.RY)(2.0 * x, 0)
784+
m2 = qml.measure(0)
785+
786+
mid_measure = 0.5 * m2 if meas_obj == "composite_mcm" else [m1, m2]
787+
measurement_key = "wires" if isinstance(meas_obj, list) else "op"
788+
measurement_value = mid_measure if isinstance(meas_obj, str) else meas_obj
789+
return measure_f(**{measurement_key: measurement_value})
790+
791+
func1 = func
792+
func2 = qml.defer_measurements(func)
793+
794+
results1 = func1(param)
795+
results2 = func2(param)
796+
797+
validate_measurements(measure_f, shots, results1, results2)

tests/transforms/test_dynamic_one_shot.py

+165
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,168 @@ def test_len_measurements_mcms(measure, aux_measure, n_meas):
155155
assert len(aux_tape.measurements) == n_meas + n_mcms
156156
assert isinstance(aux_tape.measurements[0], aux_measure)
157157
assert all(isinstance(m, SampleMP) for m in aux_tape.measurements[1:])
158+
159+
160+
def assert_results(res, shots, n_mcms):
161+
"""Helper to check that expected raw results of executing the transformed tape are correct"""
162+
assert len(res) == shots
163+
# One for the non-MeasurementValue MP, and the rest of the mid-circuit measurements
164+
assert all(len(r) == n_mcms + 1 for r in res)
165+
# Not validating distribution of results as device sampling unit tests already validate
166+
# that samples are generated correctly.
167+
168+
169+
@pytest.mark.jax
170+
@pytest.mark.parametrize("measure_f", (qml.expval, qml.probs, qml.sample, qml.var))
171+
@pytest.mark.parametrize("shots", [20, [20, 21]])
172+
@pytest.mark.parametrize("n_mcms", [1, 3])
173+
def test_tape_results_jax(shots, n_mcms, measure_f):
174+
"""Test that the simulation results of a tape are correct with jax parameters"""
175+
import jax
176+
177+
dev = qml.device("default.qubit", wires=4, shots=shots, seed=jax.random.PRNGKey(123))
178+
param = jax.numpy.array(np.pi / 2)
179+
180+
mv = qml.measure(0)
181+
mp = mv.measurements[0]
182+
183+
tape = qml.tape.QuantumScript(
184+
[qml.RX(param, 0), mp] + [MidMeasureMP(0, id=str(i)) for i in range(n_mcms - 1)],
185+
[measure_f(op=qml.PauliZ(0)), measure_f(op=mv)],
186+
shots=shots,
187+
)
188+
189+
tapes, _ = qml.dynamic_one_shot(tape)
190+
results = dev.execute(tapes)[0]
191+
192+
# The transformed tape never has a shot vector
193+
if isinstance(shots, list):
194+
shots = sum(shots)
195+
196+
assert_results(results, shots, n_mcms)
197+
198+
199+
@pytest.mark.jax
200+
@pytest.mark.parametrize(
201+
"measure_f, expected1, expected2",
202+
[
203+
(qml.expval, 1.0, 1.0),
204+
(qml.probs, [1, 0], [0, 1]),
205+
(qml.sample, 1, 1),
206+
(qml.var, 0.0, 0.0),
207+
],
208+
)
209+
@pytest.mark.parametrize("shots", [20, [20, 21]])
210+
@pytest.mark.parametrize("n_mcms", [1, 3])
211+
def test_jax_results_processing(shots, n_mcms, measure_f, expected1, expected2):
212+
"""Test that the results of tapes are processed correctly for tapes with jax parameters"""
213+
import jax.numpy as jnp
214+
215+
mv = qml.measure(0)
216+
mp = mv.measurements[0]
217+
218+
tape = qml.tape.QuantumScript(
219+
[qml.RX(1.5, 0), mp] + [MidMeasureMP(0)] * (n_mcms - 1),
220+
[measure_f(op=qml.PauliZ(0)), measure_f(op=mv)],
221+
shots=shots,
222+
)
223+
_, fn = qml.dynamic_one_shot(tape)
224+
all_shots = sum(shots) if isinstance(shots, list) else shots
225+
226+
first_res = jnp.array([1.0, 0.0]) if measure_f == qml.probs else jnp.array(1.0)
227+
rest = jnp.array(1, dtype=int)
228+
single_shot_res = (first_res,) + (rest,) * n_mcms
229+
# Raw results for each shot are (sample_for_first_measurement,) + (sample for 1st MCM, sample for 2nd MCM, ...)
230+
raw_results = (single_shot_res,) * all_shots
231+
raw_results = (raw_results,)
232+
res = fn(raw_results)
233+
234+
if measure_f is qml.sample:
235+
# All samples 1
236+
expected1 = (
237+
[[expected1] * s for s in shots] if isinstance(shots, list) else [expected1] * shots
238+
)
239+
expected2 = (
240+
[[expected2] * s for s in shots] if isinstance(shots, list) else [expected2] * shots
241+
)
242+
else:
243+
expected1 = [expected1 for _ in shots] if isinstance(shots, list) else expected1
244+
expected2 = [expected2 for _ in shots] if isinstance(shots, list) else expected2
245+
246+
if isinstance(shots, list):
247+
assert len(res) == len(shots)
248+
for r, e1, e2 in zip(res, expected1, expected2):
249+
# Expected result is 2-list since we have two measurements in the tape
250+
assert qml.math.allclose(r, [e1, e2])
251+
else:
252+
# Expected result is 2-list since we have two measurements in the tape
253+
assert qml.math.allclose(res, [expected1, expected2])
254+
255+
256+
@pytest.mark.jax
257+
@pytest.mark.parametrize(
258+
"measure_f, expected1, expected2",
259+
[
260+
(qml.expval, 1.0, 1.0),
261+
(qml.probs, [1, 0], [0, 1]),
262+
(qml.sample, 1, 1),
263+
(qml.var, 0.0, 0.0),
264+
],
265+
)
266+
@pytest.mark.parametrize("shots", [20, [20, 22]])
267+
def test_jax_results_postselection_processing(shots, measure_f, expected1, expected2):
268+
"""Test that the results of tapes are processed correctly for tapes with jax parameters
269+
when postselecting"""
270+
import jax.numpy as jnp
271+
272+
param = jnp.array(np.pi / 2)
273+
fill_value = np.iinfo(np.int32).min
274+
mv = qml.measure(0, postselect=1)
275+
mp = mv.measurements[0]
276+
277+
tape = qml.tape.QuantumScript(
278+
[qml.RX(param, 0), mp, MidMeasureMP(0)],
279+
[measure_f(op=qml.PauliZ(0)), measure_f(op=mv)],
280+
shots=shots,
281+
)
282+
_, fn = qml.dynamic_one_shot(tape)
283+
all_shots = sum(shots) if isinstance(shots, list) else shots
284+
285+
# Alternating tuple. Only the values at odd indices are valid
286+
first_res_two_shot = (
287+
(jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0]))
288+
if measure_f == qml.probs
289+
else (jnp.array(1.0), jnp.array(0.0))
290+
)
291+
first_res = first_res_two_shot * (all_shots // 2)
292+
# Tuple of alternating 1s and 0s. Zero is invalid as postselecting on 1
293+
postselect_res = (jnp.array(1, dtype=int), jnp.array(0, dtype=int)) * (all_shots // 2)
294+
rest = (jnp.array(1, dtype=int),) * all_shots
295+
# Raw results for each shot are (sample_for_first_measurement, sample for 1st MCM, sample for 2nd MCM)
296+
raw_results = tuple(zip(first_res, postselect_res, rest))
297+
raw_results = (raw_results,)
298+
res = fn(raw_results)
299+
300+
if measure_f is qml.sample:
301+
expected1 = (
302+
[[expected1, fill_value] * (s // 2) for s in shots]
303+
if isinstance(shots, list)
304+
else [expected1, fill_value] * (shots // 2)
305+
)
306+
expected2 = (
307+
[[expected2, fill_value] * (s // 2) for s in shots]
308+
if isinstance(shots, list)
309+
else [expected2, fill_value] * (shots // 2)
310+
)
311+
else:
312+
expected1 = [expected1 for _ in shots] if isinstance(shots, list) else expected1
313+
expected2 = [expected2 for _ in shots] if isinstance(shots, list) else expected2
314+
315+
if isinstance(shots, list):
316+
assert len(res) == len(shots)
317+
for r, e1, e2 in zip(res, expected1, expected2):
318+
# Expected result is 2-list since we have two measurements in the tape
319+
assert qml.math.allclose(r, [e1, e2])
320+
else:
321+
# Expected result is 2-list since we have two measurements in the tape
322+
assert qml.math.allclose(res, [expected1, expected2])

0 commit comments

Comments
 (0)