Skip to content

Commit fc53978

Browse files
committed
always call post processing
1 parent 30f847a commit fc53978

File tree

3 files changed

+67
-14
lines changed

3 files changed

+67
-14
lines changed

qiskit/providers/experiment/experiment_data.py

+33-13
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import contextlib
2424
from collections import deque
2525
from datetime import datetime
26-
import io
2726

2827
from qiskit.providers import Job, BaseJob, Backend, BaseBackend, Provider
2928
from qiskit.result import Result
@@ -34,8 +33,13 @@
3433
from .exceptions import ExperimentError, ExperimentEntryNotFound, ExperimentEntryExists
3534
from .analysis_result import AnalysisResultV1 as AnalysisResult
3635
from .json import NumpyEncoder, NumpyDecoder
37-
from .utils import (save_data, qiskit_version, plot_to_svg_bytes,
38-
ThreadSafeOrderedDict, ThreadSafeList)
36+
from .utils import (
37+
save_data,
38+
qiskit_version,
39+
plot_to_svg_bytes,
40+
ThreadSafeOrderedDict,
41+
ThreadSafeList,
42+
)
3943

4044
LOG = logging.getLogger(__name__)
4145

@@ -166,12 +170,16 @@ def add_data(
166170
data: Union[Result, List[Result], Job, List[Job], Dict, List[Dict]],
167171
post_processing_callback: Optional[Callable] = None,
168172
**kwargs: Any,
169-
):
173+
) -> None:
170174
"""Add experiment data.
171175
172176
Note:
173177
This method is not thread safe and should not be called by the
174-
`post_processing` function.
178+
`post_processing_callback` function.
179+
180+
Note:
181+
If `data` is a ``Job``, this method waits for the job to finish
182+
and calls the `post_processing_callback` function asynchronously.
175183
176184
Args:
177185
data: Experiment data to add.
@@ -183,18 +191,16 @@ def add_data(
183191
* List[Job]: Add data from the job results.
184192
* Dict: Add this data.
185193
* List[Dict]: Add this list of data.
186-
post_processing_callback: Callback function invoked when all pending
187-
jobs finish. This ``ExperimentData`` object is the only argument
194+
post_processing_callback: Callback function invoked when data is
195+
added. If `data` is a ``Job``, the callback is only invoked when
196+
the job finishes successfully.
197+
This ``ExperimentData`` object is the only argument
188198
to be passed to the callback function.
189199
**kwargs: Keyword arguments to be passed to the callback function.
190200
Raises:
191201
TypeError: If the input data type is invalid.
192202
"""
193-
if isinstance(data, dict):
194-
self._add_single_data(data)
195-
elif isinstance(data, Result):
196-
self._add_result_data(data)
197-
elif isinstance(data, (Job, BaseJob)):
203+
if isinstance(data, (Job, BaseJob)):
198204
if self.backend and self.backend != data.backend():
199205
LOG.warning(
200206
"Adding a job from a backend (%s) that is different "
@@ -219,13 +225,22 @@ def add_data(
219225
)
220226
if self.auto_save:
221227
self.save()
228+
return
229+
230+
if isinstance(data, dict):
231+
self._add_single_data(data)
232+
elif isinstance(data, Result):
233+
self._add_result_data(data)
222234
elif isinstance(data, list):
223235
# TODO use loop instead of recursion for fewer save()
224236
for dat in data:
225237
self.add_data(dat)
226238
else:
227239
raise TypeError(f"Invalid data type {type(data)}.")
228240

241+
if post_processing_callback is not None:
242+
post_processing_callback(self, **kwargs)
243+
229244
def _wait_for_job(
230245
self,
231246
job: Union[Job, BaseJob],
@@ -364,7 +379,10 @@ def add_figures(
364379
if isinstance(figure, str):
365380
fig_name = figure
366381
else:
367-
fig_name = f"figure_{self.experiment_id}_{datetime.now().isoformat()}"
382+
fig_name = (
383+
f"figure_{self.experiment_id[:8]}_"
384+
f"{datetime.now().isoformat()}_{len(self._figures)}"
385+
)
368386
else:
369387
fig_name = figure_names[idx]
370388

@@ -387,6 +405,7 @@ def add_figures(
387405
if save and service:
388406
if HAS_MATPLOTLIB:
389407
from matplotlib import pyplot
408+
390409
if isinstance(figure, pyplot.Figure):
391410
figure = plot_to_svg_bytes(figure)
392411
data = {
@@ -661,6 +680,7 @@ def save_all(self, service: Optional["ExperimentServiceV1"] = None) -> None:
661680
for name, figure in self._figures.items():
662681
if HAS_MATPLOTLIB:
663682
from matplotlib import pyplot
683+
664684
if isinstance(figure, pyplot.Figure):
665685
figure = plot_to_svg_bytes(figure)
666686
data = {"experiment_id": self.experiment_id, "figure": figure, "figure_name": name}

qiskit/providers/experiment/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def plot_to_svg_bytes(figure: "pyplot.Figure") -> bytes:
8383
Figure in bytes.
8484
"""
8585
buf = io.BytesIO()
86-
figure.savefig(buf, format='svg')
86+
figure.savefig(buf, format="svg")
8787
buf.seek(0)
8888
figure_data = buf.read()
8989
buf.close()

test/python/providers/test_experimentdata.py

+33
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,38 @@ def _callback(_exp_data):
147147
exp_data.block_for_results()
148148
self.assertTrue(called_back)
149149

150+
def test_add_data_callback(self):
151+
"""Test add data with callback."""
152+
153+
def _callback(_exp_data):
154+
self.assertIsInstance(_exp_data, ExperimentData)
155+
nonlocal called_back_count, expected_data, subtests
156+
expected_data.extend(subtests[called_back_count][1])
157+
self.assertEqual([dat["counts"] for dat in _exp_data.data()], expected_data)
158+
called_back_count += 1
159+
160+
a_result = self._get_job_result(1)
161+
results = [self._get_job_result(1), self._get_job_result(1)]
162+
a_dict = {"counts": {"01": 518}}
163+
dicts = [{"counts": {"00": 284}}, {"counts": {"00": 14}}]
164+
165+
subtests = [
166+
(a_result, [a_result.get_counts()]),
167+
(results, [res.get_counts() for res in results]),
168+
(a_dict, [a_dict["counts"]]),
169+
(dicts, [dat["counts"] for dat in dicts]),
170+
]
171+
172+
called_back_count = 0
173+
expected_data = []
174+
exp_data = ExperimentData(backend=self.backend, experiment_type="qiskit_test")
175+
176+
for data, _ in subtests:
177+
with self.subTest(data=data):
178+
exp_data.add_data(data, post_processing_callback=_callback)
179+
180+
self.assertEqual(len(subtests), called_back_count)
181+
150182
def test_add_data_job_callback_kwargs(self):
151183
"""Test add job data with callback and additional arguments."""
152184

@@ -206,6 +238,7 @@ def test_add_figure(self):
206238
def test_add_figure_plot(self):
207239
"""Test adding a matplotlib figure."""
208240
import matplotlib.pyplot as plt
241+
209242
figure, ax = plt.subplots()
210243
ax.plot([1, 2, 3])
211244

0 commit comments

Comments
 (0)