Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove validation methods from primitive base classes (backport #11052) #11532

Merged
merged 3 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 10 additions & 29 deletions qiskit/primitives/base/base_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,10 @@
from qiskit.providers import JobV1 as Job
from qiskit.quantum_info.operators import SparsePauliOp
from qiskit.quantum_info.operators.base_operator import BaseOperator
from qiskit.utils.deprecation import deprecate_func

from ..utils import init_observable
from .base_primitive import BasePrimitive
from . import validation

if typing.TYPE_CHECKING:
from qiskit.opflow import PauliSumOp
Expand Down Expand Up @@ -175,18 +176,11 @@ def run(
TypeError: Invalid argument type given.
ValueError: Invalid argument values given.
"""
# Singular validation
circuits = self._validate_circuits(circuits)
observables = self._validate_observables(observables)
parameter_values = self._validate_parameter_values(
parameter_values,
default=[()] * len(circuits),
# Validation
circuits, observables, parameter_values = validation._validate_estimator_args(
circuits, observables, parameter_values
)

# Cross-validation
self._cross_validate_circuits_parameter_values(circuits, parameter_values)
self._cross_validate_circuits_observables(circuits, observables)

# Options
run_opts = copy(self.options)
run_opts.update_options(**run_options)
Expand All @@ -206,34 +200,21 @@ def _run(
parameter_values: tuple[tuple[float, ...], ...],
**run_options,
) -> T:
raise NotImplementedError("The subclass of BaseEstimator must implment `_run` method.")
raise NotImplementedError("The subclass of BaseEstimator must implement `_run` method.")

@staticmethod
@deprecate_func(since="0.46.0")
def _validate_observables(
observables: Sequence[BaseOperator | PauliSumOp | str] | BaseOperator | PauliSumOp | str,
) -> tuple[SparsePauliOp, ...]:
if isinstance(observables, str) or not isinstance(observables, Sequence):
observables = (observables,)
if len(observables) == 0:
raise ValueError("No observables were provided.")
return tuple(init_observable(obs) for obs in observables)
return validation._validate_observables(observables)

@staticmethod
@deprecate_func(since="0.46.0")
def _cross_validate_circuits_observables(
circuits: tuple[QuantumCircuit, ...], observables: tuple[BaseOperator | PauliSumOp, ...]
) -> None:
if len(circuits) != len(observables):
raise ValueError(
f"The number of circuits ({len(circuits)}) does not match "
f"the number of observables ({len(observables)})."
)
for i, (circuit, observable) in enumerate(zip(circuits, observables)):
if circuit.num_qubits != observable.num_qubits:
raise ValueError(
f"The number of qubits of the {i}-th circuit ({circuit.num_qubits}) does "
f"not match the number of qubits of the {i}-th observable "
f"({observable.num_qubits})."
)
return validation._cross_validate_circuits_observables(circuits, observables)

@property
def circuits(self) -> tuple[QuantumCircuit, ...]:
Expand Down
79 changes: 11 additions & 68 deletions qiskit/primitives/base/base_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
from abc import ABC
from collections.abc import Sequence

import numpy as np

from qiskit.circuit import QuantumCircuit
from qiskit.providers import Options
from qiskit.utils.deprecation import deprecate_func

from . import validation


class BasePrimitive(ABC):
Expand Down Expand Up @@ -49,83 +50,25 @@ def set_options(self, **fields):
self._run_options.update_options(**fields)

@staticmethod
@deprecate_func(since="0.46.0")
def _validate_circuits(
circuits: Sequence[QuantumCircuit] | QuantumCircuit,
) -> tuple[QuantumCircuit, ...]:
if isinstance(circuits, QuantumCircuit):
circuits = (circuits,)
elif not isinstance(circuits, Sequence) or not all(
isinstance(cir, QuantumCircuit) for cir in circuits
):
raise TypeError("Invalid circuits, expected Sequence[QuantumCircuit].")
elif not isinstance(circuits, tuple):
circuits = tuple(circuits)
if len(circuits) == 0:
raise ValueError("No circuits were provided.")
return circuits
return validation._validate_circuits(circuits)

@staticmethod
@deprecate_func(since="0.46.0")
def _validate_parameter_values(
parameter_values: Sequence[Sequence[float]] | Sequence[float] | float | None,
default: Sequence[Sequence[float]] | Sequence[float] | None = None,
) -> tuple[tuple[float, ...], ...]:
# Allow optional (if default)
if parameter_values is None:
if default is None:
raise ValueError("No default `parameter_values`, optional input disallowed.")
parameter_values = default

# Support numpy ndarray
if isinstance(parameter_values, np.ndarray):
parameter_values = parameter_values.tolist()
elif isinstance(parameter_values, Sequence):
parameter_values = tuple(
vector.tolist() if isinstance(vector, np.ndarray) else vector
for vector in parameter_values
)

# Allow single value
if _isreal(parameter_values):
parameter_values = ((parameter_values,),)
elif isinstance(parameter_values, Sequence) and not any(
isinstance(vector, Sequence) for vector in parameter_values
):
parameter_values = (parameter_values,)

# Validation
if (
not isinstance(parameter_values, Sequence)
or not all(isinstance(vector, Sequence) for vector in parameter_values)
or not all(all(_isreal(value) for value in vector) for vector in parameter_values)
):
raise TypeError("Invalid parameter values, expected Sequence[Sequence[float]].")

return tuple(tuple(float(value) for value in vector) for vector in parameter_values)
return validation._validate_parameter_values(parameter_values, default=default)

@staticmethod
@deprecate_func(since="0.46.0")
def _cross_validate_circuits_parameter_values(
circuits: tuple[QuantumCircuit, ...], parameter_values: tuple[tuple[float, ...], ...]
) -> None:
if len(circuits) != len(parameter_values):
raise ValueError(
f"The number of circuits ({len(circuits)}) does not match "
f"the number of parameter value sets ({len(parameter_values)})."
)
for i, (circuit, vector) in enumerate(zip(circuits, parameter_values)):
if len(vector) != circuit.num_parameters:
raise ValueError(
f"The number of values ({len(vector)}) does not match "
f"the number of parameters ({circuit.num_parameters}) for the {i}-th circuit."
)


def _isint(obj: Sequence[Sequence[float]] | Sequence[float] | float) -> bool:
"""Check if object is int."""
int_types = (int, np.integer)
return isinstance(obj, int_types) and not isinstance(obj, bool)


def _isreal(obj: Sequence[Sequence[float]] | Sequence[float] | float) -> bool:
"""Check if object is a real number: int or float except ``±Inf`` and ``NaN``."""
float_types = (float, np.floating)
return _isint(obj) or isinstance(obj, float_types) and float("-Inf") < obj < float("Inf")
return validation._cross_validate_circuits_parameter_values(
circuits, parameter_values=parameter_values
)
44 changes: 8 additions & 36 deletions qiskit/primitives/base/base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,13 @@
from copy import copy
from typing import Generic, TypeVar

from qiskit.circuit import ControlFlowOp, Measure, QuantumCircuit
from qiskit.utils.deprecation import deprecate_func
from qiskit.circuit import QuantumCircuit
from qiskit.circuit.parametertable import ParameterView
from qiskit.providers import JobV1 as Job

from .base_primitive import BasePrimitive
from . import validation

T = TypeVar("T", bound=Job)

Expand Down Expand Up @@ -130,15 +132,8 @@ def run(
Raises:
ValueError: Invalid arguments are given.
"""
# Singular validation
circuits = self._validate_circuits(circuits)
parameter_values = self._validate_parameter_values(
parameter_values,
default=[()] * len(circuits),
)

# Cross-validation
self._cross_validate_circuits_parameter_values(circuits, parameter_values)
# Validation
circuits, parameter_values = validation._validate_sampler_args(circuits, parameter_values)

# Options
run_opts = copy(self.options)
Expand All @@ -157,27 +152,15 @@ def _run(
parameter_values: tuple[tuple[float, ...], ...],
**run_options,
) -> T:
raise NotImplementedError("The subclass of BaseSampler must implment `_run` method.")
raise NotImplementedError("The subclass of BaseSampler must implement `_run` method.")

@classmethod
@deprecate_func(since="0.46.0")
def _validate_circuits(
cls,
circuits: Sequence[QuantumCircuit] | QuantumCircuit,
) -> tuple[QuantumCircuit, ...]:
circuits = super()._validate_circuits(circuits)
for i, circuit in enumerate(circuits):
if circuit.num_clbits == 0:
raise ValueError(
f"The {i}-th circuit does not have any classical bit. "
"Sampler requires classical bits, plus measurements "
"on the desired qubits."
)
if not _has_measure(circuit):
raise ValueError(
f"The {i}-th circuit does not have Measure instruction. "
"Without measurements, the circuit cannot be sampled from."
)
return circuits
return validation._validate_circuits(circuits, requires_measure=True)

@property
def circuits(self) -> tuple[QuantumCircuit, ...]:
Expand All @@ -196,14 +179,3 @@ def parameters(self) -> tuple[ParameterView, ...]:
List of the parameters in each quantum circuit.
"""
return tuple(self._parameters)


def _has_measure(circuit: QuantumCircuit) -> bool:
for instruction in reversed(circuit):
if isinstance(instruction.operation, Measure):
return True
elif isinstance(instruction.operation, ControlFlowOp):
for block in instruction.operation.blocks:
if _has_measure(block):
return True
return False
Loading
Loading