Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MultiStateDiscrimination output #1142

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _run_analysis(
for i in range(n_states):
counts = [0] * n_states
for point in predicted_data[i]:
counts[point] += 1
counts[int(point)] += 1
for j in range(n_states):
if j != i:
prob_wrong += counts[j] / num_shots
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def circuits(self) -> List[QuantumCircuit]:
)

# label the circuit
circuit.metadata = {"label": level}
circuit.metadata = {"label": str(level)}

circuit.measure_all()
circuits.append(circuit)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
fixes:
- |
Fixed a bug where :class:`.MultiStateDiscrimination` outputted integer state labels rather than the
string labels expected by the :class:`.DiscriminatorNode`.
37 changes: 33 additions & 4 deletions test/library/characterization/test_multi_state_discrimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@

"""Test the multi state discrimination experiments."""
from functools import wraps
from test.base import QiskitExperimentsTestCase
from test.data_processing import BaseDataProcessorTest
from unittest import SkipTest
import numpy as np

from ddt import ddt, data

Expand All @@ -22,6 +23,8 @@

from qiskit_experiments.library import MultiStateDiscrimination
from qiskit_experiments.test.pulse_backend import SingleTransmonTestBackend
from qiskit_experiments.data_processing import SkQDA
from qiskit_experiments.data_processing.nodes import DiscriminatorNode

from qiskit_experiments.warnings import HAS_SKLEARN

Expand All @@ -42,14 +45,14 @@ def wrapper(*args, **kwargs):


@ddt
class TestMultiStateDiscrimination(QiskitExperimentsTestCase):
class TestMultiStateDiscrimination(BaseDataProcessorTest):
"""Tests of the multi state discrimination experiment."""

def setUp(self):
"""Setup test variables."""
super().setUp()

self.backend = SingleTransmonTestBackend(noise=False)
self.backend = SingleTransmonTestBackend(noise=False, seed=0)

# Build x12 schedule
self.qubit = 0
Expand Down Expand Up @@ -83,7 +86,7 @@ def test_circuit_generation(self, n_states):
self.assertEqual(len(exp.circuits()), n_states)

# check the metadata
self.assertEqual(exp.circuits()[-1].metadata["label"], n_states - 1)
self.assertEqual(exp.circuits()[-1].metadata["label"], str(n_states - 1))

@data(2, 3)
@requires_sklearn
Expand All @@ -104,3 +107,29 @@ def test_discrimination_analysis(self, n_states):
"classes_"
]
self.assertEqual(len(discrim_lbls), n_states)

@requires_sklearn
def test_discriminator_data_processing(self):
"""Test that the discriminator experiment works with the discriminator node."""
discriminator = MultiStateDiscrimination([self.qubit], n_states=2, backend=self.backend)
discriminator_data = discriminator.run().block_for_results()
qda = SkQDA.from_config(discriminator_data.analysis_results("discriminator_config").value)
discriminatornode = DiscriminatorNode(discriminators=qda)

iq_data = [
[
[[0.8, -1.0], [0.1, 0.5], [-0.3, 0.4]],
[[-0.2, 0.4], [0.2, -1.0], [-0.5, 0.3]],
],
[
[[0, -1.0], [0.1, -0.5], [0.9, 0]],
[[-0.8, -0.5], [-0.1, 0.5], [0.2, 1.5]],
],
]

self.create_experiment_data(np.array(iq_data) * 1e16, single_shot=True)
fake_data = np.asarray([datum["memory"] for datum in self.iq_experiment.data()])
classified = discriminatornode(fake_data)
expected = [["110", "101"], ["000", "111"]]

self.assertListEqual(classified.tolist(), expected)