Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove circular dependency between MeasurementResult and M gate #1465

Merged
merged 3 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/qibo/gates/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
self.target_qubits = tuple(q)
self.register_name = register_name
self.collapse = collapse
self.result = MeasurementResult(self)
self.result = MeasurementResult(self.target_qubits)
# list of measurement pulses implementing the gate
# relevant for experiments only
self.pulses = None
Expand Down
14 changes: 7 additions & 7 deletions src/qibo/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ class MeasurementResult:
to use for calculations.
"""

def __init__(self, gate):
self.measurement_gate = gate
def __init__(self, qubits):
self.target_qubits = qubits
self.circuit = None

self._samples = None
Expand All @@ -97,7 +97,7 @@ def __init__(self, gate):
self._symbols = None

def __repr__(self):
qubits = self.measurement_gate.qubits
qubits = self.target_qubits
nshots = self.nshots
return f"MeasurementResult(qubits={qubits}, nshots={nshots})"

Expand All @@ -115,7 +115,7 @@ def nshots(self) -> int:

def add_shot(self, probs, backend=None):
backend = _check_backend(backend)
qubits = sorted(self.measurement_gate.target_qubits)
qubits = sorted(self.target_qubits)
shot = backend.sample_shots(probs, 1)
bshot = backend.samples_to_binary(shot, len(qubits))
if self._samples:
Expand Down Expand Up @@ -153,7 +153,7 @@ def symbols(self):
These symbols are useful for conditioning parametrized gates on measurement outcomes.
"""
if self._symbols is None:
qubits = self.measurement_gate.target_qubits
qubits = self.target_qubits
self._symbols = [MeasurementSymbol(i, self) for i in range(len(qubits))]

return self._symbols
Expand Down Expand Up @@ -186,7 +186,7 @@ def samples(self, binary=True, registers=False, backend=None):
if binary:
return self._samples

qubits = self.measurement_gate.target_qubits
qubits = self.target_qubits
return backend.samples_to_decimal(self._samples, len(qubits))

def frequencies(self, binary=True, registers=False, backend=None):
Expand All @@ -213,7 +213,7 @@ def frequencies(self, binary=True, registers=False, backend=None):
self.samples(binary=False)
)
if binary:
qubits = self.measurement_gate.target_qubits
qubits = self.target_qubits
return frequencies_to_binary(self._frequencies, len(qubits))

return self._frequencies
Expand Down
4 changes: 2 additions & 2 deletions tests/test_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def test_measurementsymbol_pickling(backend):

def test_measurementresult_nshots(backend):
gate = gates.M(*range(3))
result = MeasurementResult(gate)
result = MeasurementResult(gate.qubits)
# nshots starting from samples
nshots = 10
samples = backend.cast(
Expand All @@ -487,7 +487,7 @@ def test_measurementresult_nshots(backend):
result.register_samples(samples)
assert result.nshots == nshots
# nshots starting from frequencies
result = MeasurementResult(gate)
result = MeasurementResult(gate.qubits)
states, counts = np.unique(samples, axis=0, return_counts=True)
to_str = lambda x: [str(item) for item in x]
states = ["".join(to_str(s)) for s in states.tolist()]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@


def test_measurement_result_repr():
result = MeasurementResult(gates.M(0))
result = MeasurementResult(gates.M(0).target_qubits)
assert str(result) == "MeasurementResult(qubits=(0,), nshots=None)"


def test_measurement_result_error():
result = MeasurementResult(gates.M(0))
result = MeasurementResult(gates.M(0).qubits)
with pytest.raises(RuntimeError):
samples = result.samples()

Expand Down