Skip to content

Commit d643c23

Browse files
authored
Update SamplerQNN docs (qiskit-community#886)
1 parent c2581f1 commit d643c23

File tree

1 file changed

+50
-23
lines changed

1 file changed

+50
-23
lines changed

qiskit_machine_learning/neural_networks/sampler_qnn.py

+50-23
Original file line numberDiff line numberDiff line change
@@ -70,63 +70,87 @@ class SamplerQNN(NeuralNetwork):
7070
from the :class:`~qiskit_machine_learning.circuit.library.QNNCircuit`.
7171
7272
The output can be set up in different formats, and an optional post-processing step
73-
can be used to interpret the sampler's output in a particular context (e.g. mapping the
74-
resulting bitstring to match the number of classes).
73+
can be used to interpret or map the sampler's raw output in a particular context (e.g. mapping
74+
the resulting bitstring to match the number of classes) via an ``interpret`` function.
7575
76-
In this example the network maps the output of the quantum circuit to two classes via a custom
77-
`interpret` function:
76+
The ``output_shape`` parameter defines the shape of the output array after applying the
77+
interpret function, and can be set following the guidelines below.
7878
79-
.. code-block::
79+
* **Default behavior:** if no interpret function is provided, the default output_shape is
80+
``2**num_qubits``, which corresponds to the number of possible bit-strings for the given
81+
number of qubits.
82+
* **Custom interpret function:** when using a custom interpret function, you must specify
83+
``output_shape`` to match the expected output of the interpret function. For instance, if
84+
your interpret function maps bit-strings to two classes, you should set ``output_shape=2``.
85+
* **Number of classical registers:** if you want to reshape the output by the number of
86+
classical registers, set ``output_shape=2**circuit.num_clbits``. This is useful when
87+
the number of classical registers differs from the number of qubits.
88+
* **Tuple shape:** if the interpret function returns a tuple, ``output_shape`` should be a
89+
``tuple`` that matches the dimensions of the interpreted output.
90+
91+
In this example, the network maps the output of the quantum circuit to two classes via a custom
92+
``interpret`` function:
93+
94+
95+
.. code-block:: python
8096
8197
from qiskit import QuantumCircuit
8298
from qiskit.circuit.library import ZZFeatureMap, RealAmplitudes
8399
from qiskit_machine_learning.circuit.library import QNNCircuit
84-
85100
from qiskit_machine_learning.neural_networks import SamplerQNN
86101
87102
num_qubits = 2
88103
104+
# Define a custom interpret function that calculates the parity of the bitstring
89105
def parity(x):
90106
return f"{bin(x)}".count("1") % 2
91107
92-
# Using the QNNCircuit:
93-
# Create a parameterized 2 qubit circuit composed of the default ZZFeatureMap feature map
94-
# and RealAmplitudes ansatz.
108+
# Example 1: Using the QNNCircuit class
109+
# QNNCircuit automatically combines a feature map and an ansatz into a single circuit
95110
qnn_qc = QNNCircuit(num_qubits)
96111
97112
qnn = SamplerQNN(
98-
circuit=qnn_qc,
113+
circuit=qnn_qc, # Note that this is a QNNCircuit instance
99114
interpret=parity,
100-
output_shape=2
115+
output_shape=2 # Reshape by the number of classical registers
101116
)
102117
118+
# Do a forward pass with input data and custom weights
103119
qnn.forward(input_data=[1, 2], weights=[1, 2, 3, 4, 5, 6, 7, 8])
104120
105-
# Explicitly specifying the ansatz and feature map:
121+
# Example 2: Explicitly specifying the feature map and ansatz
122+
# Create a feature map and an ansatz separately
106123
feature_map = ZZFeatureMap(feature_dimension=num_qubits)
107124
ansatz = RealAmplitudes(num_qubits=num_qubits)
108125
126+
# Compose the feature map and ansatz manually (otherwise done within QNNCircuit)
109127
qc = QuantumCircuit(num_qubits)
110128
qc.compose(feature_map, inplace=True)
111129
qc.compose(ansatz, inplace=True)
112130
113131
qnn = SamplerQNN(
114-
circuit=qc,
132+
circuit=qc, # Note that this is a QuantumCircuit instance
115133
input_params=feature_map.parameters,
116134
weight_params=ansatz.parameters,
117135
interpret=parity,
118-
output_shape=2
136+
output_shape=2 # Reshape by the number of classical registers
119137
)
120138
139+
# Perform a forward pass with input data and weights
121140
qnn.forward(input_data=[1, 2], weights=[1, 2, 3, 4, 5, 6, 7, 8])
122141
142+
123143
The following attributes can be set via the constructor but can also be read and
124144
updated once the SamplerQNN object has been constructed.
125145
126146
Attributes:
127147
128-
sampler (BaseSampler): The sampler primitive used to compute the neural network's results.
129-
gradient (BaseSamplerGradient): A sampler gradient to be used for the backward pass.
148+
sampler (BaseSampler): The sampler primitive used to compute the neural network's
149+
results. If not provided, a default instance of the reference sampler defined by
150+
:class:`~qiskit.primitives.Sampler` will be used.
151+
gradient (BaseSamplerGradient): An optional sampler gradient used for the backward
152+
pass. If not provided, a default instance of
153+
:class:`~qiskit_machine_learning.gradients.ParamShiftSamplerGradient` will be used.
130154
"""
131155

132156
def __init__(
@@ -173,8 +197,8 @@ def __init__(
173197
sparse: Returns whether the output is sparse or not.
174198
interpret: A callable that maps the measured integer to another unsigned integer or tuple
175199
of unsigned integers. These are used as new indices for the (potentially sparse)
176-
output array. If no interpret function is passed, then an identity function will be
177-
used by this neural network.
200+
output array. If the interpret function is ``None``, then an identity function will be
201+
used by this neural network: ``lambda x: x`` (default).
178202
output_shape: The output shape of the custom interpretation. For SamplerV1, it is ignored
179203
if no custom interpret method is provided where the shape is taken to be
180204
``2^circuit.num_qubits``.
@@ -190,7 +214,7 @@ def __init__(
190214
Raises:
191215
QiskitMachineLearningError: Invalid parameter values.
192216
"""
193-
# set primitive, provide default
217+
# Set primitive, provide default
194218
if sampler is None:
195219
sampler = Sampler()
196220

@@ -226,8 +250,10 @@ def __init__(
226250
if sparse:
227251
_optionals.HAS_SPARSE.require_now("DOK")
228252

253+
self._interpret = interpret
229254
self.set_interpret(interpret, output_shape)
230-
# set gradient
255+
256+
# Set gradient
231257
if gradient is None:
232258
if isinstance(sampler, BaseSamplerV1):
233259
gradient = ParamShiftSamplerGradient(sampler=self.sampler)
@@ -283,7 +309,7 @@ def set_interpret(
283309
interpret: Callable[[int], int | tuple[int, ...]] | None = None,
284310
output_shape: int | tuple[int, ...] | None = None,
285311
) -> None:
286-
"""Change 'interpret' and corresponding 'output_shape'.
312+
"""Change ``interpret`` and corresponding ``output_shape``.
287313
288314
Args:
289315
interpret: A callable that maps the measured integer to another unsigned integer or
@@ -308,13 +334,13 @@ def _compute_output_shape(
308334
QiskitMachineLearningError: If an invalid ``sampler``provided.
309335
"""
310336

311-
# this definition is required by mypy
337+
# This definition is required by mypy
312338
output_shape_: tuple[int, ...] = (-1,)
313339

314340
if interpret is not None:
315341
if output_shape is None:
316342
raise QiskitMachineLearningError(
317-
"No output shape given; it's required when using custom interpret!"
343+
"No output shape given, but it's required when using custom interpret function."
318344
)
319345
if isinstance(output_shape, Integral):
320346
output_shape = int(output_shape)
@@ -354,6 +380,7 @@ def _postprocess(self, num_samples: int, result: SamplerResult) -> np.ndarray |
354380
else:
355381
# Fallback to 'c' if 'meas' is not available.
356382
bitstring_counts = result[i].data.c.get_counts()
383+
357384
# Normalize the counts to probabilities
358385
total_shots = sum(bitstring_counts.values())
359386
probabilities = {k: v / total_shots for k, v in bitstring_counts.items()}

0 commit comments

Comments
 (0)