Skip to content

Commit a287ecb

Browse files
authored
Add execution_span.TwirledSliceSpan (#2011)
* Add TwirledSliceSpan * reno * update docstring * try and fix lint * fix bullets * grrr bullets * fix tests
1 parent 3a69811 commit a287ecb

File tree

7 files changed

+249
-7
lines changed

7 files changed

+249
-7
lines changed

qiskit_ibm_runtime/execution_span/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,11 @@
3535
ExecutionSpans
3636
ShapeType
3737
SliceSpan
38+
TwirledSliceSpan
3839
"""
3940

4041
from .double_slice_span import DoubleSliceSpan
4142
from .execution_span import ExecutionSpan, ShapeType
4243
from .execution_spans import ExecutionSpans
4344
from .slice_span import SliceSpan
45+
from .twirled_slice_span import TwirledSliceSpan

qiskit_ibm_runtime/execution_span/double_slice_span.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,16 @@ class DoubleSliceSpan(ExecutionSpan):
2828
"""An :class:`~.ExecutionSpan` for data stored in a sliceable format.
2929
3030
This type of execution span references pub result data by assuming that it is a sliceable
31-
portion of the data where the shots are the outermost slice and the rest of the data is flattened.
32-
Therefore, for each pub dependent on this span, the constructor accepts two :class:`slice` objects,
33-
along with the corresponding shape of the data to be sliced; in contrast to
34-
:class:`~.SliceSpan`, this class does not assume that *all* shots for a particular set of parameter
35-
values are contiguous in the array of data.
31+
portion of the data where the shots are the outermost slice and the rest of the data is
32+
flattened. Therefore, for each pub dependent on this span, the constructor accepts two
33+
:class:`slice` objects, along with the corresponding shape of the data to be sliced; in contrast
34+
to :class:`~.SliceSpan`, this class does not assume that *all* shots for a particular set of
35+
parameter values are contiguous in the array of data.
3636
3737
Args:
3838
start: The start time of the span, in UTC.
3939
stop: The stop time of the span, in UTC.
40-
data_slices: A map from pub indices to ``(shape_tuple, slice, slice)``.
40+
data_slices: A map from pub indices to ``(shape_tuple, flat_shape_slice, shots_slice)``.
4141
"""
4242

4343
def __init__(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# This code is part of Qiskit.
2+
#
3+
# (C) Copyright IBM 2024.
4+
#
5+
# This code is licensed under the Apache License, Version 2.0. You may
6+
# obtain a copy of this license in the LICENSE.txt file in the root directory
7+
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
8+
#
9+
# Any modifications or derivative works of this code must retain this
10+
# copyright notice, and modified files need to carry a notice indicating
11+
# that they have been altered from the originals.
12+
13+
"""TwirledSliceSpan"""
14+
15+
from __future__ import annotations
16+
17+
from datetime import datetime
18+
from typing import Iterable
19+
20+
import math
21+
import numpy as np
22+
import numpy.typing as npt
23+
24+
from .execution_span import ExecutionSpan, ShapeType
25+
26+
27+
class TwirledSliceSpan(ExecutionSpan):
28+
"""An :class:`~.ExecutionSpan` for data stored in a sliceable format when twirling.
29+
30+
This type of execution span references pub result data that came from a twirled sampler
31+
experiment which was executed by either prepending or appending an axis to paramater values
32+
to account for twirling. Concretely, ``data_slices`` is a map from pub slices to tuples
33+
``(twirled_shape, at_front, shape_slice, shots_slice)`` where
34+
35+
* ``twirled_shape`` is the shape tuple including a twirling axis, and where the last
36+
axis is shots per randomization,
37+
* ``at_front`` is whether ``num_randomizations`` is at the front of the tuple, as
38+
opposed to right before the ``shots`` axis at the end,
39+
* ``shape_slice`` is a slice of an array of shape ``twirled_shape[:-1]``, flattened,
40+
* and ``shots_slice`` is a slice of ``twirled_shape[-1]``.
41+
42+
Args:
43+
start: The start time of the span, in UTC.
44+
stop: The stop time of the span, in UTC.
45+
data_slices: A map from pub indices to length-4 tuples described above.
46+
"""
47+
48+
def __init__(
49+
self,
50+
start: datetime,
51+
stop: datetime,
52+
data_slices: dict[int, tuple[ShapeType, bool, slice, slice]],
53+
):
54+
super().__init__(start, stop)
55+
self._data_slices = data_slices
56+
57+
def __eq__(self, other: object) -> bool:
58+
return isinstance(other, TwirledSliceSpan) and (
59+
self.start == other.start
60+
and self.stop == other.stop
61+
and self._data_slices == other._data_slices
62+
)
63+
64+
@property
65+
def pub_idxs(self) -> list[int]:
66+
return sorted(self._data_slices)
67+
68+
@property
69+
def size(self) -> int:
70+
size = 0
71+
for shape, _, shape_sl, shots_sl in self._data_slices.values():
72+
size += len(range(math.prod(shape[:-1]))[shape_sl]) * len(range(shape[-1])[shots_sl])
73+
return size
74+
75+
def mask(self, pub_idx: int) -> npt.NDArray[np.bool_]:
76+
twirled_shape, at_front, shape_sl, shots_sl = self._data_slices[pub_idx]
77+
mask = np.zeros(twirled_shape, dtype=np.bool_)
78+
mask.reshape((np.prod(twirled_shape[:-1]), twirled_shape[-1]))[(shape_sl, shots_sl)] = True
79+
80+
if at_front:
81+
# if the first axis is over twirling samples, push them right before shots
82+
ndim = len(twirled_shape)
83+
mask = mask.transpose((*range(1, ndim - 1), 0, ndim - 1))
84+
twirled_shape = twirled_shape[1:-1] + twirled_shape[:1] + twirled_shape[-1:]
85+
86+
# merge twirling axis and shots axis before returning
87+
return mask.reshape((*twirled_shape[:-2], math.prod(twirled_shape[-2:])))
88+
89+
def filter_by_pub(self, pub_idx: int | Iterable[int]) -> "TwirledSliceSpan":
90+
pub_idx = {pub_idx} if isinstance(pub_idx, int) else set(pub_idx)
91+
slices = {idx: val for idx, val in self._data_slices.items() if idx in pub_idx}
92+
return TwirledSliceSpan(self.start, self.stop, slices)

qiskit_ibm_runtime/utils/json.py

+18
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
DoubleSliceSpan,
8080
SliceSpan,
8181
ExecutionSpans,
82+
TwirledSliceSpan,
8283
)
8384

8485
from .noise_learner_result import NoiseLearnerResult
@@ -341,6 +342,16 @@ def default(self, obj: Any) -> Any: # pylint: disable=arguments-differ
341342
},
342343
}
343344
return {"__type__": "DoubleSliceSpan", "__value__": out_val}
345+
if isinstance(obj, TwirledSliceSpan):
346+
out_val = {
347+
"start": obj.start,
348+
"stop": obj.stop,
349+
"data_slices": {
350+
idx: (shape, at_front, arg_sl.start, arg_sl.stop, shot_sl.start, shot_sl.stop)
351+
for idx, (shape, at_front, arg_sl, shot_sl) in obj._data_slices.items()
352+
},
353+
}
354+
return {"__type__": "TwirledSliceSpan", "__value__": out_val}
344355
if isinstance(obj, SliceSpan):
345356
out_val = {
346357
"start": obj.start,
@@ -470,6 +481,13 @@ def object_hook(self, obj: Any) -> Any:
470481
for idx, (shape, arg0, arg1, shot0, shot1) in obj_val["data_slices"].items()
471482
}
472483
return DoubleSliceSpan(**obj_val)
484+
if obj_type == "TwirledSliceSpan":
485+
data_slices = obj_val["data_slices"]
486+
obj_val["data_slices"] = {
487+
int(idx): (tuple(shape), at_start, slice(arg0, arg1), slice(shot0, shot1))
488+
for idx, (shape, at_start, arg0, arg1, shot0, shot1) in data_slices.items()
489+
}
490+
return TwirledSliceSpan(**obj_val)
473491
if obj_type == "ExecutionSpan":
474492
new_slices = {
475493
int(idx): (tuple(shape), slice(*sl_args))
+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Added :class:`.TwirledSliceSpan`, an :class:`ExecutionSpan` to be used when
2+
twirling is enabled in the sampler. In particular, it keeps track of an extra shape
3+
axis corresponding to twirling randomizations, and also whether this axis exists at
4+
the front of the shape tuple, or right before the shots axis.

test/unit/test_data_serialization.py

+9
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
DoubleSliceSpan,
5252
SliceSpan,
5353
ExecutionSpans,
54+
TwirledSliceSpan,
5455
)
5556

5657
from .mock.fake_runtime_client import CustomResultRuntimeJob
@@ -468,6 +469,14 @@ def make_test_primitive_results(self):
468469
datetime(2024, 8, 21),
469470
{0: ((14,), slice(2, 3), slice(1, 9))},
470471
),
472+
TwirledSliceSpan(
473+
datetime(2024, 9, 20),
474+
datetime(2024, 3, 21),
475+
{
476+
0: ((14, 18, 21), True, slice(2, 3), slice(1, 9)),
477+
2: ((18, 14, 19), False, slice(2, 3), slice(1, 9)),
478+
},
479+
),
471480
]
472481
)
473482
}

test/unit/test_execution_span.py

+118-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818

1919
import numpy as np
2020
import numpy.testing as npt
21-
from qiskit_ibm_runtime.execution_span import SliceSpan, DoubleSliceSpan, ExecutionSpans
21+
from qiskit_ibm_runtime.execution_span import (
22+
SliceSpan,
23+
DoubleSliceSpan,
24+
ExecutionSpans,
25+
TwirledSliceSpan,
26+
)
2227

2328
from ..ibm_test_case import IBMTestCase
2429

@@ -222,6 +227,118 @@ def test_filter_by_pub(self):
222227
)
223228

224229

230+
@ddt.ddt
231+
class TestTwirledSliceSpan(IBMTestCase):
232+
"""Class for testing TwirledSliceSpan."""
233+
234+
def setUp(self) -> None:
235+
super().setUp()
236+
self.start1 = datetime(2024, 10, 11, 4, 31, 30)
237+
self.stop1 = datetime(2024, 10, 11, 4, 31, 34)
238+
self.slices1 = {
239+
2: ((3, 1, 5), True, slice(1), slice(2, 4)),
240+
0: ((3, 5, 18, 10), False, slice(10, 13), slice(2, 5)),
241+
}
242+
self.span1 = TwirledSliceSpan(self.start1, self.stop1, self.slices1)
243+
244+
self.start2 = datetime(2024, 10, 16, 11, 9, 20)
245+
self.stop2 = datetime(2024, 10, 16, 11, 9, 30)
246+
self.slices2 = {
247+
0: ((7, 5, 100), True, slice(3, 5), slice(20, 40)),
248+
1: ((1, 5, 2, 3), False, slice(3, 9), slice(1, 3)),
249+
}
250+
self.span2 = TwirledSliceSpan(self.start2, self.stop2, self.slices2)
251+
252+
def test_limits(self):
253+
"""Test the start and stop properties"""
254+
self.assertEqual(self.span1.start, self.start1)
255+
self.assertEqual(self.span1.stop, self.stop1)
256+
self.assertEqual(self.span2.start, self.start2)
257+
self.assertEqual(self.span2.stop, self.stop2)
258+
259+
def test_equality(self):
260+
"""Test the equality method."""
261+
self.assertEqual(self.span1, self.span1)
262+
self.assertEqual(self.span1, TwirledSliceSpan(self.start1, self.stop1, self.slices1))
263+
self.assertNotEqual(self.span1, "aoeu")
264+
self.assertNotEqual(self.span1, self.span2)
265+
266+
def test_duration(self):
267+
"""Test the duration property"""
268+
self.assertEqual(self.span1.duration, 4)
269+
self.assertEqual(self.span2.duration, 10)
270+
271+
def test_repr(self):
272+
"""Test the repr method"""
273+
expect = "start='2024-10-11 04:31:30', stop='2024-10-11 04:31:34', size=11"
274+
self.assertEqual(repr(self.span1), f"TwirledSliceSpan(<{expect}>)")
275+
276+
def test_size(self):
277+
"""Test the size property"""
278+
self.assertEqual(self.span1.size, 1 * 2 + 3 * 3)
279+
self.assertEqual(self.span2.size, 2 * 20 + 6 * 2)
280+
281+
def test_pub_idxs(self):
282+
"""Test the pub_idxs property"""
283+
self.assertEqual(self.span1.pub_idxs, [0, 2])
284+
self.assertEqual(self.span2.pub_idxs, [0, 1])
285+
286+
def test_mask(self):
287+
"""Test the mask() method"""
288+
# reminder: ((3, 1, 5), True, slice(1), slice(2, 4))
289+
mask1 = np.zeros((3, 1, 5), dtype=bool)
290+
mask1.reshape((3, 5))[:1, 2:4] = True
291+
mask1 = mask1.transpose((1, 0, 2)).reshape((1, 15))
292+
npt.assert_array_equal(self.span1.mask(2), mask1)
293+
294+
# reminder: ((1, 5, 2, 3), False, slice(3,9), slice(1, 3)),
295+
mask2 = [
296+
[
297+
[[[0, 0, 0], [0, 0, 0]]],
298+
[[[0, 0, 0], [0, 1, 1]]],
299+
[[[0, 1, 1], [0, 1, 1]]],
300+
[[[0, 1, 1], [0, 1, 1]]],
301+
[[[0, 1, 1], [0, 0, 0]]],
302+
]
303+
]
304+
mask2 = np.array(mask2, dtype=bool).reshape((1, 5, 6))
305+
npt.assert_array_equal(self.span2.mask(1), mask2)
306+
307+
@ddt.data(
308+
(0, True, True),
309+
([0, 1], True, True),
310+
([0, 1, 2], True, True),
311+
([1, 2], True, True),
312+
([1], False, True),
313+
(2, True, False),
314+
([0, 2], True, True),
315+
)
316+
@ddt.unpack
317+
def test_contains_pub(self, idx, span1_expected_res, span2_expected_res):
318+
"""The the contains_pub method"""
319+
self.assertEqual(self.span1.contains_pub(idx), span1_expected_res)
320+
self.assertEqual(self.span2.contains_pub(idx), span2_expected_res)
321+
322+
def test_filter_by_pub(self):
323+
"""The the filter_by_pub method"""
324+
self.assertEqual(
325+
self.span1.filter_by_pub([]), TwirledSliceSpan(self.start1, self.stop1, {})
326+
)
327+
self.assertEqual(
328+
self.span2.filter_by_pub([]), TwirledSliceSpan(self.start2, self.stop2, {})
329+
)
330+
331+
self.assertEqual(
332+
self.span1.filter_by_pub([1, 0]),
333+
TwirledSliceSpan(self.start1, self.stop1, {0: self.slices1[0]}),
334+
)
335+
336+
self.assertEqual(
337+
self.span1.filter_by_pub(2),
338+
TwirledSliceSpan(self.start1, self.stop1, {2: self.slices1[2]}),
339+
)
340+
341+
225342
@ddt.ddt
226343
class TestExecutionSpans(IBMTestCase):
227344
"""Class for testing ExecutionSpans."""

0 commit comments

Comments
 (0)