diff --git a/mpqp/execution/connection/aws_connection.py b/mpqp/execution/connection/aws_connection.py index 779a6797..2bb10f9e 100644 --- a/mpqp/execution/connection/aws_connection.py +++ b/mpqp/execution/connection/aws_connection.py @@ -356,7 +356,11 @@ def get_aws_braket_account_info() -> str: return result -def get_braket_device(device: AWSDevice, is_noisy: bool = False) -> "BraketDevice": +def get_braket_device( + device: AWSDevice, + is_noisy: bool = False, + is_gate_model: bool = True, +) -> "BraketDevice": """Returns the AwsDevice device associate with the AWSDevice in parameter. Args: @@ -378,13 +382,19 @@ def get_braket_device(device: AWSDevice, is_noisy: bool = False) -> "BraketDevic """ from braket.devices import LocalSimulator + from mpqp.tools.errors import ( + AWSBraketRemoteExecutionError, + DeviceJobIncompatibleError, + ) + if not device.is_remote(): if is_noisy: return LocalSimulator("braket_dm") else: return LocalSimulator() - import pkg_resources + from importlib.metadata import PackageNotFoundError, version + from botocore.exceptions import NoRegionError from braket.aws import AwsDevice, AwsSession @@ -393,11 +403,17 @@ def get_braket_device(device: AWSDevice, is_noisy: bool = False) -> "BraketDevic braket_client = boto3.client("braket", region_name=device.get_region()) aws_session = AwsSession(braket_client=braket_client) - mpqp_version = pkg_resources.get_distribution("mpqp").version[:3] + + try: + mpqp_version = version("mpqp") + except PackageNotFoundError: + mpqp_version = "0.0.0+unknown" + aws_session.add_braket_user_agent( user_agent="APN/1.0 ColibriTD/1.0 MPQP/" + mpqp_version ) - return AwsDevice(device.get_arn(), aws_session=aws_session) + braket_device = AwsDevice(device.get_arn(), aws_session=aws_session) + except ValueError as ve: raise AWSBraketRemoteExecutionError( "Failed to retrieve remote AWS device. Please check the arn, or if the " @@ -410,6 +426,23 @@ def get_braket_device(device: AWSDevice, is_noisy: bool = False) -> "BraketDevic "\nTrace: " + str(err) ) + if is_gate_model: + actions = getattr(getattr(braket_device, "properties", None), "action", None) + if actions is not None: + supported = [getattr(k, "value", str(k)) for k in actions.keys()] + supports_gate_model = any( + ("openqasm" in action.lower()) or ("jaqcd" in action.lower()) + for action in supported + ) + if not supports_gate_model: + raise DeviceJobIncompatibleError( + f"{device.name} does not support gate-model workloads. " + f"Supported Braket action types: {supported}. " + "This is an AHS device, which cannot run MPQP QCircuit." + ) + + return braket_device + def get_all_task_ids() -> list[str]: """Retrieves all the task ids of this account/group from AWS. diff --git a/mpqp/execution/providers/aws.py b/mpqp/execution/providers/aws.py index b74e35f0..ab6009ef 100644 --- a/mpqp/execution/providers/aws.py +++ b/mpqp/execution/providers/aws.py @@ -16,7 +16,11 @@ from mpqp.execution.job import Job, JobStatus, JobType from mpqp.execution.result import Result, Sample, StateVector from mpqp.noise.noise_model import NoiseModel -from mpqp.tools.errors import AWSBraketRemoteExecutionError, DeviceJobIncompatibleError +from mpqp.tools.errors import ( + AWSBraketRemoteExecutionError, + DeviceJobIncompatibleError, + DeviceJobIncompatibleWarning, +) if TYPE_CHECKING: from braket.circuits import Circuit @@ -109,16 +113,31 @@ def run_braket(job: Job) -> Result: f"{job.device} instead" ) + import warnings + from braket.tasks import GateModelQuantumTaskResult - if isinstance(job.measure, ExpectationMeasure): - return run_braket_observable(job) - _, task = submit_job_braket(job) - res = task.result() - if TYPE_CHECKING: - assert isinstance(res, GateModelQuantumTaskResult) + try: + if isinstance(job.measure, ExpectationMeasure): + return run_braket_observable(job) + + _, task = submit_job_braket(job) + res = task.result() + if TYPE_CHECKING: + assert isinstance(res, GateModelQuantumTaskResult) - return extract_result(res, job, job.device) + return extract_result(res, job, job.device) + + except DeviceJobIncompatibleError as e: + warnings.warn(str(e), DeviceJobIncompatibleWarning, stacklevel=5) + + job.status = JobStatus.ERROR + return Result( + job, + data=None, + errors="Unsupported Braket backend for QCircuit (see warning).", + shots=0, + ) def run_braket_observable(job: Job): @@ -151,6 +170,7 @@ def run_braket_observable(job: Job): job.device, is_noisy=bool(job.circuit.noises), ) + if job.measure is None: raise NotImplementedError("job.measure is None") assert isinstance(job.measure, ExpectationMeasure) @@ -270,7 +290,7 @@ def run_braket_observable(job: Job): ) if braket_sum is not None: - from braket.program_sets import ProgramSet, CircuitBinding + from braket.program_sets import CircuitBinding, ProgramSet from braket.tasks.program_set_quantum_task_result import ( ProgramSetQuantumTaskResult, ) diff --git a/mpqp/execution/result.py b/mpqp/execution/result.py index 4f679741..38d1e747 100644 --- a/mpqp/execution/result.py +++ b/mpqp/execution/result.py @@ -30,7 +30,7 @@ import numpy.typing as npt from mpqp.core.instruction.measurement.basis_measure import BasisMeasure -from mpqp.execution import Job, JobType +from mpqp.execution import Job, JobStatus, JobType from mpqp.execution.devices import AvailableDevice from mpqp.tools.display import clean_1D_array, clean_number_repr from mpqp.tools.errors import ResultAttributeError @@ -288,8 +288,8 @@ class Result: def __init__( self, job: Job, - data: float | dict["str", float] | StateVector | list[Sample], - errors: Optional[float | dict[Any, Any]] = None, + data: float | dict["str", float] | StateVector | list[Sample] | None, + errors: Optional[float | dict[Any, Any] | str] = None, shots: int = 0, ): self.job = job @@ -305,6 +305,11 @@ def __init__( """See parameter description.""" self._data = data + if data is None: + if job.status != JobStatus.ERROR: + raise TypeError("Result data cannot be None unless job.status == ERROR") + return + # depending on the type of job, fills the result info from the data in parameter if job.job_type == JobType.OBSERVABLE: if not isinstance(data, float) and not isinstance(data, dict): @@ -458,6 +463,9 @@ def __str__(self): label = "" if self.job.circuit.label is None else self.job.circuit.label + ", " header = f"Result: {label}{type(self.device).__name__}, {self.device.name}" + if self.job.status == JobStatus.ERROR: + return f"{header}\n Error: {self.error}" + if self.job.job_type == JobType.SAMPLE: measures = self.job.circuit.measurements if not len(measures) == 1: diff --git a/mpqp/tools/errors.py b/mpqp/tools/errors.py index f44b357e..54a14154 100644 --- a/mpqp/tools/errors.py +++ b/mpqp/tools/errors.py @@ -34,6 +34,10 @@ class DeviceJobIncompatibleError(ValueError): for the selected device (for example SAMPLE job on a statevector simulator).""" +class DeviceJobIncompatibleWarning(UserWarning): + """A warning is issued when a job is not compatible with the selected device.""" + + class RemoteExecutionError(ConnectionError): """Raised when an error occurred during a remote connection, submission or execution."""