|
18 | 18 |
|
19 | 19 | import numpy as np
|
20 | 20 | import numpy.testing as npt
|
21 |
| -from qiskit_ibm_runtime.execution_span import SliceSpan, DoubleSliceSpan, ExecutionSpans |
| 21 | +from qiskit_ibm_runtime.execution_span import ( |
| 22 | + SliceSpan, |
| 23 | + DoubleSliceSpan, |
| 24 | + ExecutionSpans, |
| 25 | + TwirledSliceSpan, |
| 26 | +) |
22 | 27 |
|
23 | 28 | from ..ibm_test_case import IBMTestCase
|
24 | 29 |
|
@@ -222,6 +227,118 @@ def test_filter_by_pub(self):
|
222 | 227 | )
|
223 | 228 |
|
224 | 229 |
|
| 230 | +@ddt.ddt |
| 231 | +class TestTwirledSliceSpan(IBMTestCase): |
| 232 | + """Class for testing TwirledSliceSpan.""" |
| 233 | + |
| 234 | + def setUp(self) -> None: |
| 235 | + super().setUp() |
| 236 | + self.start1 = datetime(2024, 10, 11, 4, 31, 30) |
| 237 | + self.stop1 = datetime(2024, 10, 11, 4, 31, 34) |
| 238 | + self.slices1 = { |
| 239 | + 2: ((3, 1, 5), True, slice(1), slice(2, 4)), |
| 240 | + 0: ((3, 5, 18, 10), False, slice(10, 13), slice(2, 5)), |
| 241 | + } |
| 242 | + self.span1 = TwirledSliceSpan(self.start1, self.stop1, self.slices1) |
| 243 | + |
| 244 | + self.start2 = datetime(2024, 10, 16, 11, 9, 20) |
| 245 | + self.stop2 = datetime(2024, 10, 16, 11, 9, 30) |
| 246 | + self.slices2 = { |
| 247 | + 0: ((7, 5, 100), True, slice(3, 5), slice(20, 40)), |
| 248 | + 1: ((1, 5, 2, 3), False, slice(3, 9), slice(1, 3)), |
| 249 | + } |
| 250 | + self.span2 = TwirledSliceSpan(self.start2, self.stop2, self.slices2) |
| 251 | + |
| 252 | + def test_limits(self): |
| 253 | + """Test the start and stop properties""" |
| 254 | + self.assertEqual(self.span1.start, self.start1) |
| 255 | + self.assertEqual(self.span1.stop, self.stop1) |
| 256 | + self.assertEqual(self.span2.start, self.start2) |
| 257 | + self.assertEqual(self.span2.stop, self.stop2) |
| 258 | + |
| 259 | + def test_equality(self): |
| 260 | + """Test the equality method.""" |
| 261 | + self.assertEqual(self.span1, self.span1) |
| 262 | + self.assertEqual(self.span1, TwirledSliceSpan(self.start1, self.stop1, self.slices1)) |
| 263 | + self.assertNotEqual(self.span1, "aoeu") |
| 264 | + self.assertNotEqual(self.span1, self.span2) |
| 265 | + |
| 266 | + def test_duration(self): |
| 267 | + """Test the duration property""" |
| 268 | + self.assertEqual(self.span1.duration, 4) |
| 269 | + self.assertEqual(self.span2.duration, 10) |
| 270 | + |
| 271 | + def test_repr(self): |
| 272 | + """Test the repr method""" |
| 273 | + expect = "start='2024-10-11 04:31:30', stop='2024-10-11 04:31:34', size=11" |
| 274 | + self.assertEqual(repr(self.span1), f"TwirledSliceSpan(<{expect}>)") |
| 275 | + |
| 276 | + def test_size(self): |
| 277 | + """Test the size property""" |
| 278 | + self.assertEqual(self.span1.size, 1 * 2 + 3 * 3) |
| 279 | + self.assertEqual(self.span2.size, 2 * 20 + 6 * 2) |
| 280 | + |
| 281 | + def test_pub_idxs(self): |
| 282 | + """Test the pub_idxs property""" |
| 283 | + self.assertEqual(self.span1.pub_idxs, [0, 2]) |
| 284 | + self.assertEqual(self.span2.pub_idxs, [0, 1]) |
| 285 | + |
| 286 | + def test_mask(self): |
| 287 | + """Test the mask() method""" |
| 288 | + # reminder: ((3, 1, 5), True, slice(1), slice(2, 4)) |
| 289 | + mask1 = np.zeros((3, 1, 5), dtype=bool) |
| 290 | + mask1.reshape((3, 5))[:1, 2:4] = True |
| 291 | + mask1 = mask1.transpose((1, 0, 2)).reshape((1, 15)) |
| 292 | + npt.assert_array_equal(self.span1.mask(2), mask1) |
| 293 | + |
| 294 | + # reminder: ((1, 5, 2, 3), False, slice(3,9), slice(1, 3)), |
| 295 | + mask2 = [ |
| 296 | + [ |
| 297 | + [[[0, 0, 0], [0, 0, 0]]], |
| 298 | + [[[0, 0, 0], [0, 1, 1]]], |
| 299 | + [[[0, 1, 1], [0, 1, 1]]], |
| 300 | + [[[0, 1, 1], [0, 1, 1]]], |
| 301 | + [[[0, 1, 1], [0, 0, 0]]], |
| 302 | + ] |
| 303 | + ] |
| 304 | + mask2 = np.array(mask2, dtype=bool).reshape((1, 5, 6)) |
| 305 | + npt.assert_array_equal(self.span2.mask(1), mask2) |
| 306 | + |
| 307 | + @ddt.data( |
| 308 | + (0, True, True), |
| 309 | + ([0, 1], True, True), |
| 310 | + ([0, 1, 2], True, True), |
| 311 | + ([1, 2], True, True), |
| 312 | + ([1], False, True), |
| 313 | + (2, True, False), |
| 314 | + ([0, 2], True, True), |
| 315 | + ) |
| 316 | + @ddt.unpack |
| 317 | + def test_contains_pub(self, idx, span1_expected_res, span2_expected_res): |
| 318 | + """The the contains_pub method""" |
| 319 | + self.assertEqual(self.span1.contains_pub(idx), span1_expected_res) |
| 320 | + self.assertEqual(self.span2.contains_pub(idx), span2_expected_res) |
| 321 | + |
| 322 | + def test_filter_by_pub(self): |
| 323 | + """The the filter_by_pub method""" |
| 324 | + self.assertEqual( |
| 325 | + self.span1.filter_by_pub([]), TwirledSliceSpan(self.start1, self.stop1, {}) |
| 326 | + ) |
| 327 | + self.assertEqual( |
| 328 | + self.span2.filter_by_pub([]), TwirledSliceSpan(self.start2, self.stop2, {}) |
| 329 | + ) |
| 330 | + |
| 331 | + self.assertEqual( |
| 332 | + self.span1.filter_by_pub([1, 0]), |
| 333 | + TwirledSliceSpan(self.start1, self.stop1, {0: self.slices1[0]}), |
| 334 | + ) |
| 335 | + |
| 336 | + self.assertEqual( |
| 337 | + self.span1.filter_by_pub(2), |
| 338 | + TwirledSliceSpan(self.start1, self.stop1, {2: self.slices1[2]}), |
| 339 | + ) |
| 340 | + |
| 341 | + |
225 | 342 | @ddt.ddt
|
226 | 343 | class TestExecutionSpans(IBMTestCase):
|
227 | 344 | """Class for testing ExecutionSpans."""
|
|
0 commit comments