Skip to content

Commit 3a7a951

Browse files
authored
Merge pull request #1459 from qiboteam/pauli_basis_speedup
Optimizations for the `qibo.quantum_info.basis.pauli_basis` and `vectorization` function
2 parents b53ce52 + 30fb9bb commit 3a7a951

File tree

3 files changed

+70
-36
lines changed

3 files changed

+70
-36
lines changed

src/qibo/quantum_info/basis.py

+21-24
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from functools import reduce
21
from itertools import product
32
from typing import Optional
43

@@ -92,43 +91,41 @@ def pauli_basis(
9291
backend = _check_backend(backend)
9392

9493
pauli_labels = {"I": matrices.I, "X": matrices.X, "Y": matrices.Y, "Z": matrices.Z}
95-
basis_single = [pauli_labels[label] for label in pauli_order]
94+
dim = 2**nqubits
95+
basis_single = backend.cast([pauli_labels[label] for label in pauli_order])
96+
einsum = np.einsum if backend.name == "tensorflow" else backend.np.einsum
9697

9798
if nqubits > 1:
98-
basis_full = list(product(basis_single, repeat=nqubits))
99-
basis_full = [reduce(np.kron, row) for row in basis_full]
99+
input_indices = [range(3 * i, 3 * (i + 1)) for i in range(nqubits)]
100+
output_indices = (i for indices in zip(*input_indices) for i in indices)
101+
operands = [basis_single for _ in range(nqubits)]
102+
inputs = [item for pair in zip(operands, input_indices) for item in pair]
103+
basis_full = einsum(*inputs, output_indices).reshape(4**nqubits, dim, dim)
100104
else:
101105
basis_full = basis_single
102106

103-
basis_full = backend.cast(basis_full, dtype=basis_full[0].dtype)
104-
105107
if vectorize and sparse:
106-
basis, indexes = [], []
107-
for row in basis_full:
108-
row = vectorization(row, order=order, backend=backend)
109-
row_indexes = backend.np.flatnonzero(row)
110-
indexes.append(row_indexes)
111-
basis.append(row[row_indexes])
112-
del row
108+
if backend.name == "tensorflow":
109+
nonzero = np.nonzero
110+
elif backend.name == "pytorch":
111+
nonzero = lambda x: backend.np.nonzero(x, as_tuple=True)
112+
else:
113+
nonzero = backend.np.nonzero
114+
basis = vectorization(basis_full, order=order, backend=backend)
115+
indices = nonzero(basis)
116+
basis = basis[indices].reshape(-1, dim)
117+
indices = indices[1].reshape(-1, dim)
118+
113119
elif vectorize and not sparse:
114-
basis = [
115-
vectorization(
116-
backend.cast(matrix, dtype=matrix.dtype), order=order, backend=backend
117-
)
118-
for matrix in basis_full
119-
]
120+
basis = vectorization(basis_full, order=order, backend=backend)
120121
else:
121122
basis = basis_full
122123

123-
basis = backend.cast(basis, dtype=basis[0].dtype)
124-
125124
if normalize:
126125
basis = basis / np.sqrt(2**nqubits)
127126

128127
if vectorize and sparse:
129-
indexes = backend.cast(indexes, dtype=indexes[0][0].dtype)
130-
131-
return basis, indexes
128+
return basis, indices
132129

133130
return basis
134131

src/qibo/quantum_info/superoperator_transformations.py

+24-12
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ def vectorization(state, order: str = "row", backend=None):
2727
.. math::
2828
|\\rho) = \\sum_{k, l} \\, \\rho_{kl} \\, \\ket{l} \\otimes \\ket{k}
2929
30+
If ``state`` is a 3-dimensional tensor it is interpreted as a batch of states.
3031
Args:
31-
state: state vector or density matrix.
32+
state: statevector, density matrix, an array of statevectors, or an array of density matrices.
3233
order (str, optional): If ``"row"``, vectorization is performed
3334
row-wise. If ``"column"``, vectorization is performed
3435
column-wise. If ``"system"``, a block-vectorization is
@@ -41,13 +42,13 @@ def vectorization(state, order: str = "row", backend=None):
4142
ndarray: Liouville representation of ``state``.
4243
"""
4344
if (
44-
(len(state.shape) >= 3)
45+
(len(state.shape) > 3)
4546
or (len(state) == 0)
4647
or (len(state.shape) == 2 and state.shape[0] != state.shape[1])
4748
):
4849
raise_error(
4950
TypeError,
50-
f"Object must have dims either (k,) or (k,k), but have dims {state.shape}.",
51+
f"Object must have dims either (k,), (k, k), (N, 1, k) or (N, k, k), but have dims {state.shape}.",
5152
)
5253

5354
if not isinstance(order, str):
@@ -63,25 +64,36 @@ def vectorization(state, order: str = "row", backend=None):
6364

6465
backend = _check_backend(backend)
6566

67+
dims = state.shape[-1]
68+
6669
if len(state.shape) == 1:
6770
state = backend.np.outer(state, backend.np.conj(state))
71+
elif len(state.shape) == 3 and state.shape[1] == 1:
72+
state = backend.np.einsum(
73+
"aij,akl->aijkl", state, backend.np.conj(state)
74+
).reshape(state.shape[0], dims, dims)
6875

6976
if order == "row":
70-
state = backend.np.reshape(state, (1, -1))[0]
77+
state = backend.np.reshape(state, (-1, dims**2))
7178
elif order == "column":
72-
state = state.T
73-
state = backend.np.reshape(state, (1, -1))[0]
79+
indices = list(range(len(state.shape)))
80+
indices[-2:] = reversed(indices[-2:])
81+
state = backend.np.transpose(state, indices)
82+
state = backend.np.reshape(state, (-1, dims**2))
7483
else:
75-
dim = len(state)
76-
nqubits = int(np.log2(dim))
84+
nqubits = int(np.log2(state.shape[-1]))
7785

78-
new_axis = []
86+
new_axis = [0]
7987
for qubit in range(nqubits):
80-
new_axis += [qubit + nqubits, qubit]
88+
new_axis.extend([qubit + nqubits + 1, qubit + 1])
8189

82-
state = backend.np.reshape(state, [2] * 2 * nqubits)
90+
state = backend.np.reshape(state, [-1] + [2] * 2 * nqubits)
8391
state = backend.np.transpose(state, new_axis)
84-
state = backend.np.reshape(state, (-1,))
92+
state = backend.np.reshape(state, (-1, 2 ** (2 * nqubits)))
93+
94+
state = backend.np.squeeze(
95+
state, axis=tuple(i for i, ax in enumerate(state.shape) if ax == 1)
96+
)
8597

8698
return state
8799

tests/test_quantum_info_superoperator_transformations.py

+25
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,31 @@ def test_vectorization(backend, nqubits, order, statevector):
169169
backend.assert_allclose(matrix, matrix_test, atol=PRECISION_TOL)
170170

171171

172+
@pytest.mark.parametrize("order", ["row", "column", "system"])
173+
@pytest.mark.parametrize("nqubits", [1, 2, 3])
174+
@pytest.mark.parametrize("statevector", [True, False])
175+
def test_batched_vectorization(backend, nqubits, order, statevector):
176+
if statevector:
177+
state = backend.cast(
178+
[random_statevector(2**nqubits, 42, backend=backend) for _ in range(3)]
179+
).reshape(3, 1, -1)
180+
else:
181+
state = backend.cast(
182+
[
183+
random_density_matrix(2**nqubits, seed=42, backend=backend)
184+
for _ in range(3)
185+
]
186+
)
187+
188+
batched_vec = vectorization(state, order=order, backend=backend)
189+
for i, element in enumerate(state):
190+
if statevector:
191+
element = element.ravel()
192+
backend.assert_allclose(
193+
batched_vec[i], vectorization(element, order=order, backend=backend)
194+
)
195+
196+
172197
@pytest.mark.parametrize("order", ["row", "column", "system"])
173198
@pytest.mark.parametrize("nqubits", [2, 3, 4, 5])
174199
def test_unvectorization(backend, nqubits, order):

0 commit comments

Comments
 (0)