Skip to content

Commit 87e2a93

Browse files
committed
Add Qiskit native QPY ParameterExpression serialization
With the release of symengine 0.13.0 we discovered a version dependence on the payload format used for serializing symengine expressions. This was worked around in Qiskit#13251 but this is not a sustainable solution and only works for symengine 0.11.0 and 0.13.0 (there was no 0.12.0). While there was always the option to use sympy to serialize the underlying symbolic expression (there is a `use_symengine` flag on `qpy.dumps` you can set to `False` to do this) the sympy serialzation has several tradeoffs most importantly is much higher runtime overhead. To solve the issue moving forward a qiskit native representation of the parameter expression object is necessary for serialization. This commit bumps the QPY format version to 13 and adds a new serialization format for ParameterExpression objects. This new format is a serialization of the API calls made to ParameterExpression that resulted in the creation of the underlying object. To facilitate this the ParameterExpression class is expanded to store an internal "replay" record of the API calls used to construct the ParameterExpression object. This internal list is what gets serialized by QPY and then on deserialization the "replay" is replayed to reconstruct the expression object. This is a different approach to the previous QPY representations of the ParameterExpression objects which instead represented the internal state stored in the ParameterExpression object with the symbolic expression from symengine (or a sympy copy of the expression). Doing this directly in Qiskit isn't viable though because symengine's internal expression tree is not exposed to Python directly. There isn't any method (private or public) to walk the expression tree to construct a serialization format based off of it. Converting symengine to a sympy expression and then using sympy's API to walk the expression tree is a possibility but that would tie us to sympy which would be problematic for Qiskit#13267 and Qiskit#13131, have significant runtime overhead, and it would be just easier to rely on sympy's native serialization tools. The tradeoff with this approach is that it does increase the memory overhead of the `ParameterExpression` class because for each element in the expression we have to store a record of it. Depending on the depth of the expression tree this also could be a lot larger than symengine's internal representation as we store the raw api calls made to create the ParameterExpression but symengine is likely simplifying it's internal representation as it builds it out. But I personally think this tradeoff is worthwhile as it ties the serialization format to the Qiskit objects instead of relying on a 3rd party library. This also gives us the flexibility of changing the internal symbolic expression library internally in the future if we decide to stop using symengine at any point. Fixes Qiskit#13252
1 parent 9a1d8d3 commit 87e2a93

File tree

5 files changed

+483
-43
lines changed

5 files changed

+483
-43
lines changed

qiskit/circuit/parameter.py

+2
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def __init__(
8787
self._hash = hash((self._parameter_keys, self._symbol_expr))
8888
self._parameter_symbols = {self: symbol}
8989
self._name_map = None
90+
self._qpy_replay = []
9091

9192
def assign(self, parameter, value):
9293
if parameter != self:
@@ -172,3 +173,4 @@ def __setstate__(self, state):
172173
self._hash = hash((self._parameter_keys, self._symbol_expr))
173174
self._parameter_symbols = {self: self._symbol_expr}
174175
self._name_map = None
176+
self._qpy_replay = []

qiskit/circuit/parameterexpression.py

+133-33
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
"""
1515

1616
from __future__ import annotations
17+
18+
from dataclasses import dataclass
19+
from enum import IntEnum
1720
from typing import Callable, Union
1821

1922
import numbers
@@ -30,12 +33,79 @@
3033
ParameterValueType = Union["ParameterExpression", float]
3134

3235

36+
class _OPCode(IntEnum):
37+
ADD = 0
38+
SUB = 1
39+
MUL = 2
40+
DIV = 3
41+
POW = 4
42+
SIN = 5
43+
COS = 6
44+
TAN = 7
45+
ASIN = 8
46+
ACOS = 9
47+
EXP = 10
48+
LOG = 11
49+
SIGN = 12
50+
DERIV = 13
51+
CONJ = 14
52+
SUBSTITUTE = 15
53+
ABS = 16
54+
ATAN = 17
55+
56+
57+
_OP_CODE_MAP = (
58+
"__add__",
59+
"__sub__",
60+
"__mul__",
61+
"__truediv__",
62+
"__pow__",
63+
"sin",
64+
"cos",
65+
"tan",
66+
"arcsin",
67+
"arccos",
68+
"exp",
69+
"log",
70+
"sign",
71+
"gradient",
72+
"conjugate",
73+
"subs",
74+
"abs",
75+
"arctan",
76+
)
77+
78+
79+
def op_code_to_method(op_code: _OPCode):
80+
"""Return the method name for a given op_code."""
81+
return _OP_CODE_MAP[op_code]
82+
83+
84+
@dataclass
85+
class _INSTRUCTION:
86+
op: _OPCode
87+
lhs: ParameterValueType
88+
rhs: ParameterValueType | None = None
89+
90+
91+
@dataclass
92+
class _SUBS:
93+
binds: dict
94+
op: _OPCode = _OPCode.SUBSTITUTE
95+
96+
3397
class ParameterExpression:
3498
"""ParameterExpression class to enable creating expressions of Parameters."""
3599

36-
__slots__ = ["_parameter_symbols", "_parameter_keys", "_symbol_expr", "_name_map"]
100+
__slots__ = [
101+
"_parameter_symbols",
102+
"_parameter_keys",
103+
"_symbol_expr",
104+
"_name_map",
105+
"_qpy_replay",
106+
]
37107

38-
def __init__(self, symbol_map: dict, expr):
108+
def __init__(self, symbol_map: dict, expr, *, _qpy_replay=None):
39109
"""Create a new :class:`ParameterExpression`.
40110
41111
Not intended to be called directly, but to be instantiated via operations
@@ -54,6 +124,10 @@ def __init__(self, symbol_map: dict, expr):
54124
self._parameter_keys = frozenset(p._hash_key() for p in self._parameter_symbols)
55125
self._symbol_expr = expr
56126
self._name_map: dict | None = None
127+
if _qpy_replay is not None:
128+
self._qpy_replay = _qpy_replay
129+
else:
130+
self._qpy_replay = []
57131

58132
@property
59133
def parameters(self) -> set:
@@ -69,8 +143,11 @@ def _names(self) -> dict:
69143

70144
def conjugate(self) -> "ParameterExpression":
71145
"""Return the conjugate."""
146+
new_op = _INSTRUCTION(_OPCode.CONJ, self)
147+
new_replay = self._qpy_replay.copy()
148+
new_replay.append(new_op)
72149
conjugated = ParameterExpression(
73-
self._parameter_symbols, symengine.conjugate(self._symbol_expr)
150+
self._parameter_symbols, symengine.conjugate(self._symbol_expr), _qpy_replay=new_replay
74151
)
75152
return conjugated
76153

@@ -117,6 +194,7 @@ def bind(
117194
self._raise_if_passed_unknown_parameters(parameter_values.keys())
118195
self._raise_if_passed_nan(parameter_values)
119196

197+
new_op = _SUBS(parameter_values)
120198
symbol_values = {}
121199
for parameter, value in parameter_values.items():
122200
if (param_expr := self._parameter_symbols.get(parameter)) is not None:
@@ -143,7 +221,12 @@ def bind(
143221
f"(Expression: {self}, Bindings: {parameter_values})."
144222
)
145223

146-
return ParameterExpression(free_parameter_symbols, bound_symbol_expr)
224+
new_replay = self._qpy_replay.copy()
225+
new_replay.append(new_op)
226+
227+
return ParameterExpression(
228+
free_parameter_symbols, bound_symbol_expr, _qpy_replay=new_replay
229+
)
147230

148231
def subs(
149232
self, parameter_map: dict, allow_unknown_parameters: bool = False
@@ -175,6 +258,7 @@ def subs(
175258
for p in replacement_expr.parameters
176259
}
177260
self._raise_if_parameter_names_conflict(inbound_names, parameter_map.keys())
261+
new_op = _SUBS(parameter_map)
178262

179263
# Include existing parameters in self not set to be replaced.
180264
new_parameter_symbols = {
@@ -192,8 +276,12 @@ def subs(
192276
new_parameter_symbols[p] = symbol_type(p.name)
193277

194278
substituted_symbol_expr = self._symbol_expr.subs(symbol_map)
279+
new_replay = self._qpy_replay.copy()
280+
new_replay.append(new_op)
195281

196-
return ParameterExpression(new_parameter_symbols, substituted_symbol_expr)
282+
return ParameterExpression(
283+
new_parameter_symbols, substituted_symbol_expr, _qpy_replay=new_replay
284+
)
197285

198286
def _raise_if_passed_unknown_parameters(self, parameters):
199287
unknown_parameters = parameters - self.parameters
@@ -231,7 +319,11 @@ def _raise_if_parameter_names_conflict(self, inbound_parameters, outbound_parame
231319
)
232320

233321
def _apply_operation(
234-
self, operation: Callable, other: ParameterValueType, reflected: bool = False
322+
self,
323+
operation: Callable,
324+
other: ParameterValueType,
325+
reflected: bool = False,
326+
op_code: _OPCode = None,
235327
) -> "ParameterExpression":
236328
"""Base method implementing math operations between Parameters and
237329
either a constant or a second ParameterExpression.
@@ -253,7 +345,6 @@ def _apply_operation(
253345
A new expression describing the result of the operation.
254346
"""
255347
self_expr = self._symbol_expr
256-
257348
if isinstance(other, ParameterExpression):
258349
self._raise_if_parameter_names_conflict(other._names)
259350
parameter_symbols = {**self._parameter_symbols, **other._parameter_symbols}
@@ -266,10 +357,14 @@ def _apply_operation(
266357

267358
if reflected:
268359
expr = operation(other_expr, self_expr)
360+
new_op = _INSTRUCTION(op_code, other, self)
269361
else:
270362
expr = operation(self_expr, other_expr)
363+
new_op = _INSTRUCTION(op_code, self, other)
364+
new_replay = self._qpy_replay.copy()
365+
new_replay.append(new_op)
271366

272-
out_expr = ParameterExpression(parameter_symbols, expr)
367+
out_expr = ParameterExpression(parameter_symbols, expr, _qpy_replay=new_replay)
273368
out_expr._name_map = self._names.copy()
274369
if isinstance(other, ParameterExpression):
275370
out_expr._names.update(other._names.copy())
@@ -313,81 +408,86 @@ def gradient(self, param) -> Union["ParameterExpression", complex]:
313408
return float(expr_grad)
314409

315410
def __add__(self, other):
316-
return self._apply_operation(operator.add, other)
411+
return self._apply_operation(operator.add, other, op_code=_OPCode.ADD)
317412

318413
def __radd__(self, other):
319-
return self._apply_operation(operator.add, other, reflected=True)
414+
return self._apply_operation(operator.add, other, reflected=True, op_code=_OPCode.ADD)
320415

321416
def __sub__(self, other):
322-
return self._apply_operation(operator.sub, other)
417+
return self._apply_operation(operator.sub, other, op_code=_OPCode.SUB)
323418

324419
def __rsub__(self, other):
325-
return self._apply_operation(operator.sub, other, reflected=True)
420+
return self._apply_operation(operator.sub, other, reflected=True, op_code=_OPCode.SUB)
326421

327422
def __mul__(self, other):
328-
return self._apply_operation(operator.mul, other)
423+
return self._apply_operation(operator.mul, other, op_code=_OPCode.MUL)
329424

330425
def __pos__(self):
331-
return self._apply_operation(operator.mul, 1)
426+
return self._apply_operation(operator.mul, 1, op_code=_OPCode.MUL)
332427

333428
def __neg__(self):
334-
return self._apply_operation(operator.mul, -1)
429+
return self._apply_operation(operator.mul, -1, op_code=_OPCode.MUL)
335430

336431
def __rmul__(self, other):
337-
return self._apply_operation(operator.mul, other, reflected=True)
432+
return self._apply_operation(operator.mul, other, reflected=True, op_code=_OPCode.MUL)
338433

339434
def __truediv__(self, other):
340435
if other == 0:
341436
raise ZeroDivisionError("Division of a ParameterExpression by zero.")
342-
return self._apply_operation(operator.truediv, other)
437+
return self._apply_operation(operator.truediv, other, op_code=_OPCode.DIV)
343438

344439
def __rtruediv__(self, other):
345-
return self._apply_operation(operator.truediv, other, reflected=True)
440+
return self._apply_operation(operator.truediv, other, reflected=True, op_code=_OPCode.DIV)
346441

347442
def __pow__(self, other):
348-
return self._apply_operation(pow, other)
443+
return self._apply_operation(pow, other, op_code=_OPCode.POW)
349444

350445
def __rpow__(self, other):
351-
return self._apply_operation(pow, other, reflected=True)
352-
353-
def _call(self, ufunc):
354-
return ParameterExpression(self._parameter_symbols, ufunc(self._symbol_expr))
446+
return self._apply_operation(pow, other, reflected=True, op_code=_OPCode.POW)
447+
448+
def _call(self, ufunc, op_code):
449+
new_op = _INSTRUCTION(op_code, self)
450+
new_replay = self._qpy_replay.copy()
451+
new_replay.append(new_op)
452+
return ParameterExpression(
453+
self._parameter_symbols, ufunc(self._symbol_expr), _qpy_replay=new_replay
454+
)
355455

356456
def sin(self):
357457
"""Sine of a ParameterExpression"""
358-
return self._call(symengine.sin)
458+
return self._call(symengine.sin, op_code=_OPCode.SIN)
359459

360460
def cos(self):
361461
"""Cosine of a ParameterExpression"""
362-
return self._call(symengine.cos)
462+
return self._call(symengine.cos, op_code=_OPCode.COS)
363463

364464
def tan(self):
365465
"""Tangent of a ParameterExpression"""
366-
return self._call(symengine.tan)
466+
return self._call(symengine.tan, op_code=_OPCode.TAN)
367467

368468
def arcsin(self):
369469
"""Arcsin of a ParameterExpression"""
370-
return self._call(symengine.asin)
470+
return self._call(symengine.asin, op_code=_OPCode.ASIN)
371471

372472
def arccos(self):
373473
"""Arccos of a ParameterExpression"""
374-
return self._call(symengine.acos)
474+
return self._call(symengine.acos, op_code=_OPCode.ACOS)
375475

376476
def arctan(self):
377477
"""Arctan of a ParameterExpression"""
378-
return self._call(symengine.atan)
478+
return self._call(symengine.atan, op_code=_OPCode.ATAN)
379479

380480
def exp(self):
381481
"""Exponential of a ParameterExpression"""
382-
return self._call(symengine.exp)
482+
return self._call(symengine.exp, op_code=_OPCode.EXP)
383483

384484
def log(self):
385485
"""Logarithm of a ParameterExpression"""
386-
return self._call(symengine.log)
486+
return self._call(symengine.log, op_code=_OPCode.LOG)
387487

388488
def sign(self):
389489
"""Sign of a ParameterExpression"""
390-
return self._call(symengine.sign)
490+
return self._call(symengine.sign, op_code=_OPCode.SIGN)
391491

392492
def __repr__(self):
393493
return f"{self.__class__.__name__}({str(self)})"
@@ -455,7 +555,7 @@ def __deepcopy__(self, memo=None):
455555

456556
def __abs__(self):
457557
"""Absolute of a ParameterExpression"""
458-
return self._call(symengine.Abs)
558+
return self._call(symengine.Abs, _OPCode.ABS)
459559

460560
def abs(self):
461561
"""Absolute of a ParameterExpression"""

0 commit comments

Comments
 (0)