1
1
import os
2
2
from importlib import import_module
3
3
4
+ import networkx as nx
4
5
import numpy as np
5
6
6
7
from qibo .backends .abstract import Backend
@@ -65,10 +66,11 @@ def list_available(self) -> dict:
65
66
return available_backends
66
67
67
68
68
- class GlobalBackend (NumpyBackend ):
69
- """The global backend will be used as default by ``circuit.execute()``."""
69
+ class _Global :
70
+ _backend = None
71
+ _transpiler = None
72
+ # TODO: resolve circular import with qibo.transpiler.pipeline.Passes
70
73
71
- _instance = None
72
74
_dtypes = {"double" : "complex128" , "single" : "complex64" }
73
75
_default_order = [
74
76
{"backend" : "qibojit" , "platform" : "cupy" },
@@ -78,39 +80,88 @@ class GlobalBackend(NumpyBackend):
78
80
{"backend" : "pytorch" },
79
81
]
80
82
81
- def __new__ (cls ):
82
- if cls ._instance is not None :
83
- return cls ._instance
83
+ @classmethod
84
+ def backend (cls ):
85
+ """Get the current backend. If no backend is set, it will create one."""
86
+ if cls ._backend is not None :
87
+ return cls ._backend
88
+ cls ._backend = cls ._create_backend ()
89
+ log .info (f"Using { cls ._backend } backend on { cls ._backend .device } " )
90
+ return cls ._backend
84
91
85
- backend = os .environ .get ("QIBO_BACKEND" )
86
- if backend : # pragma: no cover
92
+ @classmethod
93
+ def _create_backend (cls ):
94
+ backend_env = os .environ .get ("QIBO_BACKEND" )
95
+ if backend_env : # pragma: no cover
87
96
# Create backend specified by user
88
97
platform = os .environ .get ("QIBO_PLATFORM" )
89
- cls . _instance = construct_backend (backend , platform = platform )
98
+ backend = construct_backend (backend_env , platform = platform )
90
99
else :
91
100
# Create backend according to default order
92
101
for kwargs in cls ._default_order :
93
102
try :
94
- cls . _instance = construct_backend (** kwargs )
103
+ backend = construct_backend (** kwargs )
95
104
break
96
105
except (ImportError , MissingBackend ):
97
106
pass
98
107
99
- if cls . _instance is None : # pragma: no cover
108
+ if backend is None : # pragma: no cover
100
109
raise_error (RuntimeError , "No backends available." )
110
+ return backend
111
+
112
+ @classmethod
113
+ def set_backend (cls , backend , ** kwargs ):
114
+ cls ._backend = construct_backend (backend , ** kwargs )
115
+ log .info (f"Using { cls ._backend } backend on { cls ._backend .device } " )
101
116
102
- log .info (f"Using { cls ._instance } backend on { cls ._instance .device } " )
103
- return cls ._instance
117
+ @classmethod
118
+ def transpiler (cls ):
119
+ """Get the current transpiler. If no transpiler is set, it will create one."""
120
+ if cls ._transpiler is not None :
121
+ return cls ._transpiler
122
+
123
+ cls ._transpiler = cls ._default_transpiler ()
124
+ return cls ._transpiler
125
+
126
+ @classmethod
127
+ def set_transpiler (cls , transpiler ):
128
+ cls ._transpiler = transpiler
129
+ # TODO: check if transpiler is valid on the backend
104
130
105
131
@classmethod
106
- def set_backend (cls , backend , ** kwargs ): # pragma: no cover
132
+ def _default_transpiler (cls ):
133
+ from qibo .transpiler .optimizer import Preprocessing
134
+ from qibo .transpiler .pipeline import Passes
135
+ from qibo .transpiler .placer import Trivial
136
+ from qibo .transpiler .router import Sabre
137
+ from qibo .transpiler .unroller import NativeGates , Unroller
138
+
139
+ qubits = cls ._backend .qubits
140
+ natives = cls ._backend .natives
141
+ connectivity_edges = cls ._backend .connectivity
107
142
if (
108
- cls . _instance is None
109
- or cls . _instance . name != backend
110
- or cls . _instance . platform != kwargs . get ( "platform" )
143
+ qubits is not None
144
+ and natives is not None
145
+ and connectivity_edges is not None
111
146
):
112
- cls ._instance = construct_backend (backend , ** kwargs )
113
- log .info (f"Using { cls ._instance } backend on { cls ._instance .device } " )
147
+ # only for q{i} naming
148
+ node_mapping = {q : i for i , q in enumerate (qubits )}
149
+ edges = [
150
+ (node_mapping [e [0 ]], node_mapping [e [1 ]]) for e in connectivity_edges
151
+ ]
152
+ connectivity = nx .Graph (edges )
153
+
154
+ return Passes (
155
+ connectivity = connectivity ,
156
+ passes = [
157
+ Preprocessing (connectivity ),
158
+ Trivial (connectivity ),
159
+ Sabre (connectivity ),
160
+ Unroller (NativeGates [natives ]),
161
+ ],
162
+ )
163
+
164
+ return Passes (passes = [])
114
165
115
166
116
167
class QiboMatrices :
@@ -147,53 +198,97 @@ def create(self, dtype):
147
198
148
199
149
200
def get_backend ():
150
- return str (GlobalBackend ())
201
+ """Get the current backend."""
202
+ return _Global .backend ()
151
203
152
204
153
205
def set_backend (backend , ** kwargs ):
154
- GlobalBackend .set_backend (backend , ** kwargs )
206
+ """Set the current backend.
207
+
208
+ Args:
209
+ backend (str): Name of the backend to use.
210
+ kwargs (dict): Additional arguments for the backend.
211
+ """
212
+ _Global .set_backend (backend , ** kwargs )
213
+
214
+
215
+ def get_transpiler ():
216
+ """Get the current transpiler."""
217
+ return _Global .transpiler ()
218
+
219
+
220
+ def get_transpiler_name ():
221
+ """Get the name of the current transpiler as a string."""
222
+ return str (_Global .transpiler ())
223
+
224
+
225
+ def set_transpiler (transpiler ):
226
+ """Set the current transpiler.
227
+
228
+ Args:
229
+ transpiler (Passes): The transpiler to use.
230
+ """
231
+ _Global .set_transpiler (transpiler )
155
232
156
233
157
234
def get_precision ():
158
- return GlobalBackend ().precision
235
+ """Get the precision of the backend."""
236
+ return get_backend ().precision
159
237
160
238
161
239
def set_precision (precision ):
162
- GlobalBackend ().set_precision (precision )
163
- matrices .create (GlobalBackend ().dtype )
240
+ """Set the precision of the backend.
241
+
242
+ Args:
243
+ precision (str): Precision to use.
244
+ """
245
+ get_backend ().set_precision (precision )
246
+ matrices .create (get_backend ().dtype )
164
247
165
248
166
249
def get_device ():
167
- return GlobalBackend ().device
250
+ """Get the device of the backend."""
251
+ return get_backend ().device
168
252
169
253
170
254
def set_device (device ):
255
+ """Set the device of the backend.
256
+
257
+ Args:
258
+ device (str): Device to use.
259
+ """
171
260
parts = device [1 :].split (":" )
172
261
if device [0 ] != "/" or len (parts ) < 2 or len (parts ) > 3 :
173
262
raise_error (
174
263
ValueError ,
175
264
"Device name should follow the pattern: /{device type}:{device number}." ,
176
265
)
177
- backend = GlobalBackend ()
266
+ backend = get_backend ()
178
267
backend .set_device (device )
179
268
log .info (f"Using { backend } backend on { backend .device } " )
180
269
181
270
182
271
def get_threads ():
183
- return GlobalBackend ().nthreads
272
+ """Get the number of threads used by the backend."""
273
+ return get_backend ().nthreads
184
274
185
275
186
276
def set_threads (nthreads ):
277
+ """Set the number of threads used by the backend.
278
+
279
+ Args:
280
+ nthreads (int): Number of threads to use.
281
+ """
187
282
if not isinstance (nthreads , int ):
188
283
raise_error (TypeError , "Number of threads must be integer." )
189
284
if nthreads < 1 :
190
285
raise_error (ValueError , "Number of threads must be positive." )
191
- GlobalBackend ().set_threads (nthreads )
286
+ get_backend ().set_threads (nthreads )
192
287
193
288
194
289
def _check_backend (backend ):
195
290
if backend is None :
196
- return GlobalBackend ()
291
+ return get_backend ()
197
292
198
293
return backend
199
294
0 commit comments