Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mermin refactor #1038

Merged
merged 35 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
5d1b73a
feat: Mermin acquisition function
andrea-pasquale Mar 24, 2024
58e7426
Merge branch 'main' into mermin
andrea-pasquale Sep 3, 2024
b05f7c5
fix: Add mermin pulses
andrea-pasquale Sep 4, 2024
59e4edb
first commit with generalization of mermin - angle and qubits
igres26 Oct 9, 2024
15c896e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 9, 2024
703a0ce
fixes mermin - not completed
igres26 Oct 28, 2024
3219031
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2024
7bba59b
refactor: remove mermin with circuits
Edoardo-Pedicillo Nov 5, 2024
e314e48
refactor: clean create_mermin_sequence(s)
Edoardo-Pedicillo Nov 5, 2024
959bb41
refactor: define proper data in MerminData
Edoardo-Pedicillo Nov 5, 2024
04377f9
refactor: update fit function
Edoardo-Pedicillo Nov 13, 2024
974de63
refactor: define STRING_TYPE
Edoardo-Pedicillo Nov 14, 2024
6ff492a
feat: define targets property
Edoardo-Pedicillo Nov 14, 2024
4511075
fix: typo in plot function
Edoardo-Pedicillo Nov 14, 2024
d1bc0da
refactor docs: update docs and remove circuit mermin
Edoardo-Pedicillo Nov 14, 2024
f2e9c34
Merge branch 'main' into mermin_refactor and fix tests
Edoardo-Pedicillo Nov 14, 2024
4c2e575
refactor: remove action_qq.yml
Edoardo-Pedicillo Nov 14, 2024
de1d327
fix: fix merging bugs
Edoardo-Pedicillo Nov 14, 2024
75d5097
fix: remove try-except and fix lint
Edoardo-Pedicillo Nov 14, 2024
358d15e
fix: rename variable
Edoardo-Pedicillo Nov 14, 2024
b7039ff
fix: evaluate outputs also when mitigation is False
Edoardo-Pedicillo Nov 14, 2024
2061ebb
Apply suggestions from code review
Edoardo-Pedicillo Nov 20, 2024
c8b54f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2024
4d71606
fix: define PADDING constant
Edoardo-Pedicillo Nov 20, 2024
1c8d9e0
fix: change n_targets
Edoardo-Pedicillo Nov 20, 2024
87b4ebd
refactor: readout mitigation matrix acquisition and post-processing
Edoardo-Pedicillo Nov 20, 2024
6cd7e29
fix: propagate changes in the other routines and fix tests
Edoardo-Pedicillo Nov 20, 2024
f35f166
fix: remove pickle
Edoardo-Pedicillo Nov 20, 2024
b7a9249
fix: circuit execution issue
Edoardo-Pedicillo Nov 21, 2024
a5f5dd0
fix: use 2q circuits
Edoardo-Pedicillo Nov 21, 2024
4c0013a
fix: add check before dumping the mitigation matrix
Edoardo-Pedicillo Nov 21, 2024
d7f30da
Merge pull request #1047 from qiboteam/chsh
Edoardo-Pedicillo Nov 21, 2024
5204ba3
Merge pull request #1045 from qiboteam/readout_mitigation
Edoardo-Pedicillo Nov 25, 2024
804abff
fix: evaluate the state properly
Edoardo-Pedicillo Nov 25, 2024
407d537
fix: connect platform after circuit execution
Edoardo-Pedicillo Nov 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/qibocal/protocols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
chsh_pulses,
correct_virtual_z_phases,
correct_virtual_z_phases_signal,
mermin,
optimize_two_qubit_gate,
)
from .two_qubit_state_tomography import two_qubit_state_tomography
Expand Down Expand Up @@ -149,5 +150,6 @@
"standard_rb_2q",
"standard_rb_2q_inter",
"optimize_two_qubit_gate",
"mermin",
"ramsey_zz",
]
1 change: 1 addition & 0 deletions src/qibocal/protocols/two_qubit_interaction/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .chevron import chevron, chevron_signal
from .chsh import chsh_circuits, chsh_pulses
from .mermin import mermin
from .optimize import optimize_two_qubit_gate
from .virtual_z_phases import correct_virtual_z_phases
from .virtual_z_phases_signal import correct_virtual_z_phases_signal
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .protocol import mermin
259 changes: 259 additions & 0 deletions src/qibocal/protocols/two_qubit_interaction/mermin/protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go
from qibolab import ExecutionParameters
from qibolab.platform import Platform
from qibolab.qubits import QubitId

from qibocal.auto.operation import Data, Parameters, Results, Routine

from ...readout_mitigation_matrix import readout_mitigation_matrix
from ...utils import STRING_TYPE, calculate_frequencies
from .pulses import create_mermin_sequences
from .utils import (
compute_mermin,
get_mermin_coefficients,
get_mermin_polynomial,
get_readout_basis,
)


@dataclass
class MerminParameters(Parameters):
"""Mermin experiment input parameters."""

ntheta: int
"""Number of angles probed linearly between 0 and 2 pi."""
native: Optional[bool] = False
"""If True a circuit will be created using only GPI2 and CZ gates."""
apply_error_mitigation: Optional[bool] = False
"""Error mitigation model"""


MerminType = np.dtype(
[
("theta", float),
("basis", STRING_TYPE),
("state", STRING_TYPE),
("frequency", int),
]
)


@dataclass
class MerminData(Data):
"""Mermin Data structure."""

thetas: list
"""Angles probed."""
data: dict[list[QubitId], npt.NDArray[MerminType]] = field(default_factory=dict)
"""Raw data acquired."""
mitigation_matrix: dict[list[QubitId], npt.NDArray[np.float64]] = field(
default_factory=dict
)
"""Mitigation matrix computed using the readout_mitigation_matrix protocol."""

@property
def targets(self):
return list(self.data.keys())


@dataclass
class MerminResults(Results):
"""Mermin Results class."""

mermin: dict[tuple[QubitId, ...], npt.NDArray[np.float64]] = field(
default_factory=dict
)
"""Raw Mermin value."""

mermin_mitigated: dict[tuple[QubitId, ...], npt.NDArray[np.float64]] = field(
default_factory=dict
)
"""Mitigated Mermin value."""


def _acquisition(
params: MerminParameters,
platform: Platform,
targets: list[list[QubitId]],
) -> MerminData:
r"""Data acquisition for Mermin protocol using pulse sequences."""

thetas = np.linspace(0, 2 * np.pi, params.ntheta)
data = MerminData(thetas=thetas.tolist())
if params.apply_error_mitigation:
mitigation_data, _ = readout_mitigation_matrix.acquisition(
readout_mitigation_matrix.parameters_type.load(
dict(pulses=True, nshots=params.nshots)
),
platform,
targets,
)

mitigation_results, _ = readout_mitigation_matrix.fit(mitigation_data)
data.mitigation_matrix = mitigation_results.readout_mitigation_matrix

for qubits in targets:
mermin_polynomial = get_mermin_polynomial(len(qubits))
readout_basis = get_readout_basis(mermin_polynomial)

for theta in thetas:
mermin_sequences = create_mermin_sequences(
platform, qubits, readout_basis=readout_basis, theta=theta
)
options = ExecutionParameters(nshots=params.nshots)
# TODO: use unrolling
for basis, sequence in mermin_sequences.items():
results = platform.execute_pulse_sequence(sequence, options=options)
frequencies = calculate_frequencies(results, qubits)
for state, frequency in enumerate(frequencies.values()):
data.register_qubit(
MerminType,
tuple(qubits),
dict(
theta=np.array([theta]),
basis=np.array([basis]),
state=np.array([str(format(state, f"0{len(qubits)}b"))]),
frequency=np.array([frequency]),
),
)
return data


def _fit(data: MerminData) -> MerminResults:
"""Fitting for Mermin protocol."""
targets = data.targets
results = {qubits: [] for qubits in targets}
mitigated_results = {qubits: [] for qubits in targets}
mermin_polynomial = get_mermin_polynomial(len(targets))
basis = np.unique(data.data[targets[0]].basis)
mermin_coefficients = get_mermin_coefficients(mermin_polynomial)
for qubits in targets:
for theta in data.thetas:
qubit_data = data.data[qubits]
outputs = []
mitigated_outputs = []
for base in basis:

data_filter = (qubit_data.basis == base) & (qubit_data.theta == theta)
state_freq = qubit_data[data_filter].frequency

outputs.append(
{
format(i, f"0{len(qubits)}b"): freq
for i, freq in enumerate(state_freq)
}
)

if data.mitigation_matrix:
mitigated_output = np.dot(
data.mitigation_matrix[qubits],
state_freq,
)
mitigated_outputs.append(
{
format(i, f"0{len(qubits)}b"): freq
for i, freq in enumerate(mitigated_output)
}
)
if data.mitigation_matrix:
mitigated_results[tuple(qubits)].append(
compute_mermin(mitigated_outputs, mermin_coefficients)
)
results[tuple(qubits)].append(compute_mermin(outputs, mermin_coefficients))
return MerminResults(
mermin=results,
mermin_mitigated=mitigated_results,
)


def _plot(data: MerminData, fit: MerminResults, target):
"""Plotting function for Mermin protocol."""
figures = []
targets = data.targets

n_targets = len(targets)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
n_targets = len(targets)
n_targets = len(targets[0])

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, so n_targets = len(target) would be better.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It just has to be the number of qubits involved. Any way we get that it should work.

classical_bound = 2 ** (n_targets // 2)
quantum_bound = 2 ** ((n_targets - 1) / 2) * (2 ** (n_targets // 2))

fig = go.Figure(layout_yaxis_range=[-3, 3])
if fit is not None:
fig.add_trace(
go.Scatter(
x=data.thetas,
y=fit.mermin[tuple(target)],
name="Bare",
)
)
if fit.mermin_mitigated:
fig.add_trace(
go.Scatter(
x=data.thetas,
y=fit.mermin_mitigated[tuple(target)],
name="Mitigated",
)
)

fig.add_trace(
go.Scatter(
mode="lines",
x=data.thetas,
y=[+classical_bound] * len(data.thetas),
line_color="gray",
name="Classical limit",
line_dash="dash",
legendgroup="classic",
)
)

fig.add_trace(
go.Scatter(
mode="lines",
x=data.thetas,
y=[-classical_bound] * len(data.thetas),
line_color="gray",
name="Classical limit",
legendgroup="classic",
line_dash="dash",
showlegend=False,
)
)

fig.add_trace(
go.Scatter(
mode="lines",
x=data.thetas,
y=[+quantum_bound] * len(data.thetas),
line_color="gray",
name="Quantum limit",
legendgroup="quantum",
)
)

fig.add_trace(
go.Scatter(
mode="lines",
x=data.thetas,
y=[-quantum_bound] * len(data.thetas),
line_color="gray",
name="Quantum limit",
legendgroup="quantum",
showlegend=False,
)
)

fig.update_layout(
xaxis_title="Theta [rad]",
yaxis_title="Mermin polynomial value",
xaxis=dict(range=[min(data.thetas), max(data.thetas)]),
)
figures.append(fig)

return figures, ""


mermin = Routine(_acquisition, _fit, _plot)
84 changes: 84 additions & 0 deletions src/qibocal/protocols/two_qubit_interaction/mermin/pulses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from collections import defaultdict

import numpy as np
from qibolab.pulses import PulseSequence


def create_mermin_sequence(platform, qubits, theta=None):
"""Creates the pulse sequence to generate the bell states and with a theta-measurement"""

nqubits = len(qubits)
if not theta:
theta = ((nqubits - 1) * 0.25 * np.pi) % (2 * np.pi)

virtual_z_phases = defaultdict(int)
sequence = PulseSequence()

for qubit in qubits:
sequence.add(
platform.create_RX90_pulse(
qubit, start=0, relative_phase=virtual_z_phases[qubit] + np.pi / 2
)
)

# TODO: Not hardcode topology

# qubits[0] needs to be the center qubit where everything is connected
for i in range(1, len(qubits)):
(cz_sequence1, cz_virtual_z_phases) = platform.create_CZ_pulse_sequence(
[qubits[0]] + [qubits[i]], sequence.finish + 8 # TODO: ask for the 8
)
sequence.add(cz_sequence1)
for qubit in cz_virtual_z_phases:
virtual_z_phases[qubit] += cz_virtual_z_phases[qubit]

t = sequence.finish + 8

for i in range(1, len(qubits)):
sequence.add(
platform.create_RX90_pulse(
qubits[i],
start=t,
relative_phase=virtual_z_phases[qubits[i]] - np.pi / 2,
)
)

virtual_z_phases[qubits[0]] -= theta

return sequence, virtual_z_phases


def create_mermin_sequences(platform, qubits, readout_basis, theta):
"""Creates the pulse sequences needed for the 4 measurement settings for chsh."""

mermin_sequences = {}

for basis in readout_basis:
sequence, virtual_z_phases = create_mermin_sequence(
platform, qubits, theta=theta
)
# t = sequence.finish
for i, base in enumerate(basis):
if base == "X":
sequence.add(
platform.create_RX90_pulse(
qubits[i],
start=sequence.finish,
relative_phase=virtual_z_phases[qubits[i]] + np.pi / 2,
)
)
if base == "Y":
sequence.add(
platform.create_RX90_pulse(
qubits[i],
start=sequence.finish,
relative_phase=virtual_z_phases[qubits[i]],
)
)
measurement_start = sequence.finish

for qubit in qubits:
sequence.add(platform.create_MZ_pulse(qubit, start=measurement_start))

mermin_sequences[basis] = sequence
return mermin_sequences
Loading