forked from qiskit-community/qiskit-experiments
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcomposite_curve_analysis.py
450 lines (386 loc) · 17 KB
/
composite_curve_analysis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
# This code is part of Qiskit.
#
# (C) Copyright IBM 2021.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""
Analysis class for multi-group curve fitting.
"""
# pylint: disable=invalid-name
import warnings
from typing import Dict, List, Optional, Tuple, Union
import lmfit
import numpy as np
import pandas as pd
from uncertainties import unumpy as unp
from qiskit.utils.deprecation import deprecate_func
from qiskit_experiments.framework import (
AnalysisResultData,
BaseAnalysis,
ExperimentData,
Options,
)
from qiskit_experiments.visualization import (
BaseDrawer,
BasePlotter,
CurvePlotter,
LegacyCurveCompatDrawer,
MplDrawer,
)
from .base_curve_analysis import PARAMS_ENTRY_PREFIX, BaseCurveAnalysis
from .curve_data import CurveFitResult
from .scatter_table import ScatterTable
from .utils import eval_with_uncertainties
class CompositeCurveAnalysis(BaseAnalysis):
    r"""Composite Curve Analysis.

    The :class:`.CompositeCurveAnalysis` takes multiple curve analysis instances
    and performs each analysis on the same experimental results.
    These analyses are performed independently, thus fit parameters have no correlation.
    Note that this is different from :class:`.CompositeAnalysis` which
    analyses the outcome of a composite experiment, in which multiple different
    experiments are performed.
    The :class:`.CompositeCurveAnalysis` is attached to a single experiment instance,
    which may execute similar circuits with slightly different settings.
    Experiments with different settings might be distinguished by the circuit
    metadata. The outcomes of the same set of experiments are assigned to a
    specific analysis instance in the composite curve analysis.
    This mapping is usually done with the analysis option ``filter_data`` dictionary.
    Otherwise, all analyses are performed on the same set of outcomes.

    Examples:

        In this example, we write up a composite analysis consisting of two oscillation
        analysis instances, assuming two Rabi experiments in 1-2 subspace
        starting with different initial states :math:`\in \{|0\rangle, |1\rangle\}`.
        This is a typical procedure to measure the thermal population of the qubit.

        .. code-block:: python

            from qiskit_experiments import curve_analysis as curve

            analyses = []
            for qi in (0, 1):
                analysis = curve.OscillationAnalysis(name=f"init{qi}")
                analysis.set_options(
                    result_parameters=["freq"],
                    filter_data={"init_state": qi},
                )
                analyses.append(analysis)
            analysis = curve.CompositeCurveAnalysis(analyses=analyses)

        This ``analysis`` will return two analysis result data for the fit parameter "freq"
        for experiments with the initial state :math:`|0\rangle` and :math:`|1\rangle`.
        The experimental circuits starting with different initial states must be
        distinguished by the circuit metadata ``{"init_state": 0}`` or ``{"init_state": 1}``,
        along with the "xval" in the same dictionary.
        If you want to compute another quantity using two fitting outcomes, you can
        override :meth:`CompositeCurveAnalysis._create_analysis_results` in a subclass.

    :class:`.CompositeCurveAnalysis` subclass may override following methods.

    .. rubric:: _evaluate_quality

    This method evaluates the quality of the composite fit based on
    all analysis outcomes.
    This returns "good" only when every group fit succeeded and is evaluated as "good",
    otherwise it returns "bad".

    .. rubric:: _create_analysis_results

    This method is passed all the group fit outcomes and can return a list of
    new values to be stored in the analysis results.

    .. rubric:: _create_figures

    This method creates figures by consuming the scatter table data.
    Figures are created when the analysis option ``plot`` is ``True``.

    """

    def __init__(
        self,
        analyses: List[BaseCurveAnalysis],
        name: Optional[str] = None,
    ):
        """Create a new composite curve analysis.

        Args:
            analyses: Child curve analyses. Each one is run independently on the
                same experiment data; results and plots are keyed by each child's ``name``.
            name: Name of this analysis. Defaults to the class name.
        """
        super().__init__()

        self._analyses = analyses
        self._name = name or self.__class__.__name__

    @property
    def parameters(self) -> List[str]:
        """Return parameters of this curve analysis.

        The union of all child analysis parameters, without duplicates, in the order
        of first appearance (child order, then each child's own parameter order).
        """
        # dict keys keep insertion order and give O(1) de-duplication.
        return list(
            dict.fromkeys(name for analysis in self._analyses for name in analysis.parameters)
        )

    @property
    def name(self) -> str:
        """Return name of this analysis."""
        return self._name

    @property
    def models(self) -> Dict[str, List[lmfit.Model]]:
        """Return fit models keyed on the child analysis name."""
        return {analysis.name: analysis.models for analysis in self._analyses}

    @property
    def plotter(self) -> BasePlotter:
        """A short-cut to the plotter instance."""
        return self._options.plotter

    @property
    @deprecate_func(
        since="0.5",
        additional_msg="Use `plotter` from the new visualization module instead.",
        removal_timeline="after 0.6",
        package_name="qiskit-experiments",
    )
    def drawer(self) -> BaseDrawer:
        """A short-cut for curve drawer instance, if set. ``None`` otherwise."""
        return getattr(self._options, "curve_drawer", None)

    def analyses(
        self, index: Optional[Union[str, int]] = None
    ) -> Union[BaseCurveAnalysis, List[BaseCurveAnalysis]]:
        """Return curve analysis instance.

        Args:
            index: Name of group or numerical index. If ``None``, all child analyses
                are returned.

        Returns:
            Curve analysis instance, or the list of all child analyses when ``index``
            is ``None``.

        Raises:
            ValueError: If ``index`` is a name that matches no child analysis.
        """
        if index is None:
            return self._analyses

        if isinstance(index, str):
            # The first child with a matching name wins, as with list.index.
            for analysis in self._analyses:
                if analysis.name == index:
                    return analysis
            known = [analysis.name for analysis in self._analyses]
            raise ValueError(
                f"Curve analysis '{index}' is not found. Available analyses: {known}."
            )

        return self._analyses[index]

    def _evaluate_quality(
        self,
        fit_data: Dict[str, CurveFitResult],
    ) -> Union[str, None]:
        """Evaluate quality of the fit result.

        Args:
            fit_data: Fit outcome keyed on the analysis name.

        Returns:
            String that represents fit result quality. Usually "good" or "bad".
        """
        for analysis in self._analyses:
            group_fit = fit_data[analysis.name]
            # A failed fit may carry no usable fit statistics, so it is not handed to
            # the child's quality evaluator. This matches _run_analysis, which
            # labels unsuccessful group fits "bad" without evaluating them.
            if not group_fit.success:
                return "bad"
            if analysis._evaluate_quality(group_fit) != "good":
                return "bad"
        return "good"

    # pylint: disable=unused-argument
    def _create_analysis_results(
        self,
        fit_data: Dict[str, CurveFitResult],
        quality: str,
        **metadata,
    ) -> List[AnalysisResultData]:
        """Create analysis results based on all analysis outcomes.

        The base implementation returns no extra results; subclasses override this
        to derive quantities from several group fits.

        Args:
            fit_data: Fit outcome keyed on the analysis name.
            quality: Quality of fit outcome.

        Returns:
            List of analysis result data.
        """
        return []

    def _create_figures(
        self,
        curve_data: ScatterTable,
    ) -> List["matplotlib.figure.Figure"]:
        """Create a list of figures from the curve data.

        Every child analysis contributes its series to one shared plotter. Series are
        named ``"<model name>_<analysis name>"`` so groups don't collide.

        Args:
            curve_data: Scatter data table containing all data points.

        Returns:
            A list of figures.
        """
        for analysis in self.analyses():
            group_data = curve_data.filter(analysis=analysis.name)
            model_names = analysis.model_names()
            for series_id, sub_data in group_data.iter_by_series_id():
                full_name = f"{model_names[series_id]}_{analysis.name}"

                # Plot raw data scatters
                if analysis.options.plot_raw_data:
                    raw_data = sub_data.filter(category="raw")
                    self.plotter.set_series_data(
                        series_name=full_name,
                        x=raw_data.x,
                        y=raw_data.y,
                    )

                # Plot formatted data scatters
                formatted_data = sub_data.filter(category=analysis.options.fit_category)
                self.plotter.set_series_data(
                    series_name=full_name,
                    x_formatted=formatted_data.x,
                    y_formatted=formatted_data.y,
                    y_formatted_err=formatted_data.y_err,
                )

                # Plot fit lines (only present when the group fit succeeded)
                line_data = sub_data.filter(category="fitted")
                if len(line_data) == 0:
                    continue
                fit_stdev = line_data.y_err
                self.plotter.set_series_data(
                    series_name=full_name,
                    x_interp=line_data.x,
                    y_interp=line_data.y,
                    # Drop the error band if any point has a non-finite std dev.
                    y_interp_err=fit_stdev if np.isfinite(fit_stdev).all() else None,
                )

        return [self.plotter.figure()]

    @classmethod
    def _default_options(cls) -> Options:
        """Default analysis options.

        Analysis Options:
            plotter (BasePlotter): A plotter instance to visualize
                the analysis result.
            plot (bool): Set ``True`` to create figure for fit result.
                This is ``True`` by default.
            return_fit_parameters (bool): Set ``True`` to return all fit model parameters
                with details of the fit outcome. Default to ``True``.
            return_data_points (bool): Set ``True`` to include in the analysis result
                the formatted data points given to the fitter. Default to ``False``.
            extra (Dict[str, Any]): A dictionary that is appended to all database entries
                as extra information.
        """
        options = super()._default_options()
        options.update_options(
            plotter=CurvePlotter(MplDrawer()),
            plot=True,
            return_fit_parameters=True,
            return_data_points=False,
            extra={},
        )

        # Set automatic validator for particular option values
        options.set_validator(field="plotter", validator_value=BasePlotter)

        return options

    def set_options(self, **fields):
        """Set the analysis options.

        The deprecated ``curve_drawer`` option is translated into the plotter's drawer.
        A warning is issued for every field that is not an option of this composite
        analysis, because such options usually belong to the child analyses.
        """
        # TODO remove this in Qiskit Experiments 0.6
        if "curve_drawer" in fields:
            warnings.warn(
                "The option 'curve_drawer' is replaced with 'plotter'. "
                "This option will be removed in Qiskit Experiments 0.6.",
                DeprecationWarning,
                stacklevel=2,
            )
            legacy_drawer = fields["curve_drawer"]
            if isinstance(legacy_drawer, BaseDrawer):
                # Already a new-style drawer: use it directly and consume the option.
                fields.pop("curve_drawer")
                new_drawer = legacy_drawer
            else:
                # Old-style drawer: wrap it in a compatibility drawer. The
                # ``curve_drawer`` field is kept in ``fields`` (presumably so the
                # deprecated ``drawer`` property can still return it).
                new_drawer = LegacyCurveCompatDrawer(legacy_drawer)
            plotter = self.options.plotter
            plotter.drawer = new_drawer
            fields["plotter"] = plotter

        for field in fields:
            if not hasattr(self.options, field):
                warnings.warn(
                    f"Specified option {field} doesn't exist in this analysis instance. "
                    f"Note that {self.__class__.__name__} is a composite curve analysis instance, "
                    "which consists of multiple child curve analyses. "
                    "This options may exist in each analysis instance. "
                    "Please try setting options to child analyses through '.analyses()'.",
                    UserWarning,
                )
        super().set_options(**fields)

    def _run_analysis(
        self,
        experiment_data: ExperimentData,
    ) -> Tuple[List[AnalysisResultData], List["matplotlib.figure.Figure"]]:
        """Run every child curve fit, then combine results and draw one figure.

        Args:
            experiment_data: Experiment data shared by all child analyses.

        Returns:
            A tuple of the analysis results and the list of created figures.
        """
        # Flag for plotting can be "always", "never", or "selective".
        # The analysis option overrides self._generate_figures if set.
        if self.options.get("plot", None):
            plot = "always"
        elif self.options.get("plot", None) is False:
            plot = "never"
        else:
            plot = getattr(self, "_generate_figures", "always")

        analysis_results = []
        figures = []

        fit_dataset = {}
        curve_data_set = []
        for analysis in self._analyses:
            analysis._initialize(experiment_data)
            # Children must not draw their own figures; this composite draws one combined
            # figure in _create_figures. Note this persistently updates the child's options.
            analysis.set_options(plot=False)

            metadata = analysis.options.extra.copy()
            metadata["group"] = analysis.name

            table = analysis._format_data(analysis._run_data_processing(experiment_data.data()))
            formatted_subset = table.filter(category=analysis.options.fit_category)
            fit_data = analysis._run_curve_fit(formatted_subset)
            fit_dataset[analysis.name] = fit_data

            if fit_data.success:
                quality = analysis._evaluate_quality(fit_data)
            else:
                quality = "bad"

            if self.options.return_fit_parameters:
                # Store fit status overview entry regardless of success.
                # This is sometimes useful when debugging the fitting code.
                overview = AnalysisResultData(
                    name=PARAMS_ENTRY_PREFIX + analysis.name,
                    value=fit_data,
                    quality=quality,
                    extra=metadata,
                )
                analysis_results.append(overview)

            if fit_data.success:
                # Add interpolated fit curves to the curve data table
                model_names = analysis.model_names()
                for series_id, sub_data in formatted_subset.iter_by_series_id():
                    xval = sub_data.x
                    if len(xval) == 0:
                        # If data is empty, skip drawing this model.
                        # This is the case when a fit model exists but no data is provided.
                        continue

                    # Compute X, Y values with fit parameters.
                    x_fit = np.linspace(np.min(xval), np.max(xval), num=100, dtype=float)
                    u_fit = eval_with_uncertainties(
                        x=x_fit,
                        model=analysis.models[series_id],
                        params=fit_data.ufloat_params,
                    )
                    y_fit = unp.nominal_values(u_fit)
                    if fit_data.covar is not None:
                        yerr_fit = unp.std_devs(u_fit)
                    else:
                        # Without a covariance matrix there is no uncertainty band.
                        yerr_fit = np.zeros_like(x_fit)

                    for x_pt, y_pt, yerr_pt in zip(x_fit, y_fit, yerr_fit):
                        table.add_row(
                            xval=x_pt,
                            yval=y_pt,
                            yerr=yerr_pt,
                            series_name=model_names[series_id],
                            series_id=series_id,
                            category="fitted",
                            analysis=analysis.name,
                        )

                # ``**`` unpacking already hands the callee its own dict.
                analysis_results.extend(
                    analysis._create_analysis_results(
                        fit_data=fit_data,
                        quality=quality,
                        **metadata,
                    )
                )

            if self.options.return_data_points:
                # Add raw data points
                analysis_results.extend(
                    analysis._create_curve_data(curve_data=formatted_subset, **metadata)
                )

            curve_data_set.append(table)

        combined_curve_data = ScatterTable.from_dataframe(
            pd.concat([d.dataframe for d in curve_data_set])
        )
        total_quality = self._evaluate_quality(fit_dataset)

        # After the quality is determined, plot can become a boolean flag for whether
        # to generate the figure
        plot_bool = plot == "always" or (plot == "selective" and total_quality == "bad")

        # Create analysis results by combining all fit data
        if all(result.success for result in fit_dataset.values()):
            composite_results = self._create_analysis_results(
                fit_data=fit_dataset, quality=total_quality, **self.options.extra
            )
            analysis_results.extend(composite_results)
        else:
            composite_results = []

        if plot_bool:
            self.plotter.set_supplementary_data(
                fit_red_chi={k: v.reduced_chisq for k, v in fit_dataset.items() if v.success},
                primary_results=composite_results,
            )
            figures.extend(self._create_figures(curve_data=combined_curve_data))

        return analysis_results, figures