@@ -27,8 +27,9 @@ def vectorization(state, order: str = "row", backend=None):
27
27
.. math::
28
28
|\\ rho) = \\ sum_{k, l} \\ , \\ rho_{kl} \\ , \\ ket{l} \\ otimes \\ ket{k}
29
29
30
+ If ``state`` is a 3-dimensional tensor it is interpreted as a batch of states.
30
31
Args:
31
- state: state vector or density matrix .
32
+ state: statevector, density matrix, an array of statevectors, or an array of density matrices .
32
33
order (str, optional): If ``"row"``, vectorization is performed
33
34
row-wise. If ``"column"``, vectorization is performed
34
35
column-wise. If ``"system"``, a block-vectorization is
@@ -41,13 +42,13 @@ def vectorization(state, order: str = "row", backend=None):
41
42
ndarray: Liouville representation of ``state``.
42
43
"""
43
44
if (
44
- (len (state .shape ) >= 3 )
45
+ (len (state .shape ) > 3 )
45
46
or (len (state ) == 0 )
46
47
or (len (state .shape ) == 2 and state .shape [0 ] != state .shape [1 ])
47
48
):
48
49
raise_error (
49
50
TypeError ,
50
- f"Object must have dims either (k,) or (k, k), but have dims { state .shape } ." ,
51
+ f"Object must have dims either (k,), (k, k), (N, 1, k) or (N, k, k), but have dims { state .shape } ." ,
51
52
)
52
53
53
54
if not isinstance (order , str ):
@@ -63,25 +64,36 @@ def vectorization(state, order: str = "row", backend=None):
63
64
64
65
backend = _check_backend (backend )
65
66
67
+ dims = state .shape [- 1 ]
68
+
66
69
if len (state .shape ) == 1 :
67
70
state = backend .np .outer (state , backend .np .conj (state ))
71
+ elif len (state .shape ) == 3 and state .shape [1 ] == 1 :
72
+ state = backend .np .einsum (
73
+ "aij,akl->aijkl" , state , backend .np .conj (state )
74
+ ).reshape (state .shape [0 ], dims , dims )
68
75
69
76
if order == "row" :
70
- state = backend .np .reshape (state , (1 , - 1 ))[ 0 ]
77
+ state = backend .np .reshape (state , (- 1 , dims ** 2 ))
71
78
elif order == "column" :
72
- state = state .T
73
- state = backend .np .reshape (state , (1 , - 1 ))[0 ]
79
+ indices = list (range (len (state .shape )))
80
+ indices [- 2 :] = reversed (indices [- 2 :])
81
+ state = backend .np .transpose (state , indices )
82
+ state = backend .np .reshape (state , (- 1 , dims ** 2 ))
74
83
else :
75
- dim = len (state )
76
- nqubits = int (np .log2 (dim ))
84
+ nqubits = int (np .log2 (state .shape [- 1 ]))
77
85
78
- new_axis = []
86
+ new_axis = [0 ]
79
87
for qubit in range (nqubits ):
80
- new_axis += [qubit + nqubits , qubit ]
88
+ new_axis . extend ( [qubit + nqubits + 1 , qubit + 1 ])
81
89
82
- state = backend .np .reshape (state , [2 ] * 2 * nqubits )
90
+ state = backend .np .reshape (state , [- 1 ] + [ 2 ] * 2 * nqubits )
83
91
state = backend .np .transpose (state , new_axis )
84
- state = backend .np .reshape (state , (- 1 ,))
92
+ state = backend .np .reshape (state , (- 1 , 2 ** (2 * nqubits )))
93
+
94
+ state = backend .np .squeeze (
95
+ state , axis = tuple (i for i , ax in enumerate (state .shape ) if ax == 1 )
96
+ )
85
97
86
98
return state
87
99
0 commit comments