diff --git a/cirq-google/cirq_google/engine/abstract_job.py b/cirq-google/cirq_google/engine/abstract_job.py index dc4e3a8d193..a641349f87c 100644 --- a/cirq-google/cirq_google/engine/abstract_job.py +++ b/cirq-google/cirq_google/engine/abstract_job.py @@ -141,9 +141,16 @@ def failure(self) -> tuple[str, str] | None: """Return failure code and message of the job if present.""" @abc.abstractmethod - def get_repetitions_and_sweeps(self) -> tuple[int, list[cirq.Sweep]]: + def get_repetitions_and_sweeps( + self, circuit_num: int | None = None + ) -> tuple[int, list[cirq.Sweep]]: """Returns the repetitions and sweeps for the job. + Args: + circuit_num: if this is a batch job, the index of the circuit + to return the sweeps for. This argument is zero-indexed. + Negative values index from the end of the list. + Returns: A tuple of the repetition count and list of sweeps. """ @@ -159,11 +166,11 @@ def get_calibration(self) -> calibration.Calibration | None: one was captured, else None.""" @abc.abstractmethod - def get_circuit(self, program_num: int | None = None) -> cirq.Circuit: + def get_circuit(self, circuit_num: int | None = None) -> cirq.Circuit: """Returns the cirq Circuit for the job. Args: - program_num: if this is a multi-circuit job, the index of the circuit + circuit_num: if this is a multi-circuit job, the index of the circuit to return. This argument is zero-indexed. Negative values index from the end of the list. Ignored if not multi-circuit. diff --git a/cirq-google/cirq_google/engine/abstract_job_test.py b/cirq-google/cirq_google/engine/abstract_job_test.py index 179aa86df91..822f1f61415 100644 --- a/cirq-google/cirq_google/engine/abstract_job_test.py +++ b/cirq-google/cirq_google/engine/abstract_job_test.py @@ -80,7 +80,7 @@ def get_processor(self): def get_calibration(self): pass - def get_circuit(self, program_num: int | None = None) -> cirq.Circuit: + def get_circuit(self, circuit_num: int | None = None) -> cirq.Circuit: return cirq.Circuit() def cancel(self) -> None: diff --git a/cirq-google/cirq_google/engine/abstract_local_job.py b/cirq-google/cirq_google/engine/abstract_local_job.py index d3fc5b0b8db..19ffca7e9f9 100644 --- a/cirq-google/cirq_google/engine/abstract_local_job.py +++ b/cirq-google/cirq_google/engine/abstract_local_job.py @@ -157,13 +157,48 @@ def processor_ids(self) -> list[str]: """Returns the processor ids provided when the job was created.""" return [self._processor_id] - def get_repetitions_and_sweeps(self) -> tuple[int, list[cirq.Sweep]]: + def get_repetitions_and_sweeps( + self, circuit_num: int | None = None + ) -> tuple[int, list[cirq.Sweep]]: """Returns the repetitions and sweeps for the job. + Args: + circuit_num: if this is a batch job, the index of the circuit + to return the sweeps for. This argument is zero-indexed. + Negative values index from the end of the list. + Returns: A tuple of the repetition count and list of sweeps. """ - return (self._repetitions, self._sweeps) + is_batch = self.program().is_batch() + batch_size = self.program().batch_size() if is_batch else 1 + + is_mapped = is_batch and len(self._sweeps) == batch_size and len(self._sweeps) > 1 + + if circuit_num is None: + if is_mapped: + raise ValueError( + f"This is a batch job with {len(self._sweeps)} mapped sweeps. " + "Please specify `circuit_num` to get sweeps for a specific circuit." + ) + return (self._repetitions, self._sweeps) + + if not is_batch: + if circuit_num != 0 and circuit_num != -1: + raise IndexError(f"Job is not a batch job, cannot index {circuit_num}") + return (self._repetitions, self._sweeps) + + if not is_mapped: + # Shared sweeps in a batch job, return all of them + return (self._repetitions, self._sweeps) + + # Mapped sweeps in a batch job + try: + return (self._repetitions, [self._sweeps[circuit_num]]) + except IndexError: + raise IndexError( + f"Index {circuit_num} out of range for sweeps of size {len(self._sweeps)}." + ) def get_processor(self) -> AbstractLocalProcessor: """Returns the AbstractProcessor for the processor the job is/was run on, @@ -175,15 +210,15 @@ def get_calibration(self) -> calibration.Calibration | None: from the parent Engine object.""" return self.get_processor().get_latest_calibration(int(self._create_time.timestamp())) - def get_circuit(self, program_num: int | None = None) -> cirq.Circuit: + def get_circuit(self, circuit_num: int | None = None) -> cirq.Circuit: """Returns the cirq Circuit for the job. Args: - program_num: if this is a multi-circuit job, the index of the circuit + circuit_num: if this is a multi-circuit job, the index of the circuit to return. This argument is zero-indexed. Negative values index from the end of the list. Ignored if not multi-circuit. Returns: The job's cirq Circuit. """ - return self.program().get_circuit(program_num) + return self.program().get_circuit(circuit_num) diff --git a/cirq-google/cirq_google/engine/abstract_local_job_test.py b/cirq-google/cirq_google/engine/abstract_local_job_test.py index 38fccc69971..cf3bd338106 100644 --- a/cirq-google/cirq_google/engine/abstract_local_job_test.py +++ b/cirq-google/cirq_google/engine/abstract_local_job_test.py @@ -21,6 +21,8 @@ from typing import TYPE_CHECKING from unittest import mock +import pytest + import cirq from cirq_google.cloud import quantum from cirq_google.engine.abstract_local_job import AbstractLocalJob @@ -75,14 +77,52 @@ def test_description_and_labels(): def test_reps_and_sweeps(): + # Single program (non-batch) + mock_program = mock.Mock() + mock_program.is_batch.return_value = False job = NothingJob( job_id='test', processor_id='grill', - parent_program=None, + parent_program=mock_program, repetitions=100, sweeps=[cirq.Linspace('t', 0, 10, 0.1)], ) assert job.get_repetitions_and_sweeps() == (100, [cirq.Linspace('t', 0, 10, 0.1)]) + assert job.get_repetitions_and_sweeps(0) == (100, [cirq.Linspace('t', 0, 10, 0.1)]) + with pytest.raises(IndexError, match="Job is not a batch job"): + _ = job.get_repetitions_and_sweeps(1) + + # Batch program, shared sweep + mock_program_batch = mock.Mock() + mock_program_batch.is_batch.return_value = True + mock_program_batch.batch_size.return_value = 2 + job_batch_shared = NothingJob( + job_id='test', + processor_id='grill', + parent_program=mock_program_batch, + repetitions=100, + sweeps=[cirq.Linspace('t', 0, 10, 0.1)], + ) + # Shared sweep, works with None or any index + assert job_batch_shared.get_repetitions_and_sweeps() == (100, [cirq.Linspace('t', 0, 10, 0.1)]) + assert job_batch_shared.get_repetitions_and_sweeps(0) == (100, [cirq.Linspace('t', 0, 10, 0.1)]) + assert job_batch_shared.get_repetitions_and_sweeps(1) == (100, [cirq.Linspace('t', 0, 10, 0.1)]) + + # Batch program, mapped sweeps + job_batch_mapped = NothingJob( + job_id='test', + processor_id='grill', + parent_program=mock_program_batch, + repetitions=100, + sweeps=[cirq.Linspace('t', 0, 10, 0.1), cirq.Linspace('u', 0, 5, 0.5)], + ) + # Mapped sweeps, requires index + with pytest.raises(ValueError, match="mapped sweeps"): + _ = job_batch_mapped.get_repetitions_and_sweeps() + assert job_batch_mapped.get_repetitions_and_sweeps(0) == (100, [cirq.Linspace('t', 0, 10, 0.1)]) + assert job_batch_mapped.get_repetitions_and_sweeps(1) == (100, [cirq.Linspace('u', 0, 5, 0.5)]) + with pytest.raises(IndexError, match="Index 2 out of range"): + _ = job_batch_mapped.get_repetitions_and_sweeps(2) def test_create_update_time(): diff --git a/cirq-google/cirq_google/engine/abstract_local_program.py b/cirq-google/cirq_google/engine/abstract_local_program.py index 685a49fb45e..7b91e0ea6a2 100644 --- a/cirq-google/cirq_google/engine/abstract_local_program.py +++ b/cirq-google/cirq_google/engine/abstract_local_program.py @@ -190,22 +190,47 @@ def remove_labels(self, keys: list[str]) -> AbstractProgram: del self._labels[key] return self - def get_circuit(self, program_num: int | None = None) -> cirq.Circuit: + def get_circuit(self, circuit_num: int | None = None) -> cirq.Circuit: """Returns the cirq Circuit for the program. This is only supported if the program was created with the V2 protos. Args: - program_num: if this is a multi-circuit program, the index of the circuit + circuit_num: if this is a multi-circuit program, the index of the circuit to return. This argument is zero-indexed. Negative values indexing from the end of the list. Returns: The program's cirq Circuit. """ - if program_num is not None: - return self._circuits[program_num] - return self._circuits[0] + if circuit_num is None: + if self.is_batch(): + raise ValueError( + f"This program is a batch program containing {len(self._circuits)} circuits. " + "Please specify `circuit_num` to get a specific circuit, " + "or use `get_circuits()` to get all of them." + ) + return self._circuits[0] + try: + return self._circuits[circuit_num] + except IndexError: + raise IndexError( + f"Index {circuit_num} out of range for batch program of size {len(self._circuits)}." + ) + + def get_circuits(self) -> list[cirq.Circuit]: + """Returns all the cirq Circuits for the program.""" + return self._circuits + + def is_batch(self) -> bool: + """Returns True if the program is a batch program.""" + return len(self._circuits) > 1 def batch_size(self) -> int: - """Returns the number of programs in a batch program.""" + """Returns the number of programs in a batch program. + + Raises: + ValueError: if the program created was not a batch program. + """ + if not self.is_batch(): + raise ValueError("This program is not a batch program.") return len(self._circuits) diff --git a/cirq-google/cirq_google/engine/abstract_local_program_test.py b/cirq-google/cirq_google/engine/abstract_local_program_test.py index d658618ab6b..19a1a288cc0 100644 --- a/cirq-google/cirq_google/engine/abstract_local_program_test.py +++ b/cirq-google/cirq_google/engine/abstract_local_program_test.py @@ -149,12 +149,25 @@ def test_description_and_labels(): def test_circuit(): circuit1 = cirq.Circuit(cirq.X(cirq.LineQubit(1))) circuit2 = cirq.Circuit(cirq.Y(cirq.LineQubit(2))) + + # Single circuit, non-batch program = NothingProgram([circuit1], None) - assert program.batch_size() == 1 + assert not program.is_batch() + with pytest.raises(ValueError, match="not a batch program"): + _ = program.batch_size() assert program.get_circuit() == circuit1 assert program.get_circuit(0) == circuit1 - assert program.batch_size() == 1 + assert program.get_circuits() == [circuit1] + + # Multi circuit (always batch) program = NothingProgram([circuit1, circuit2], None) + assert program.is_batch() assert program.batch_size() == 2 + with pytest.raises(ValueError, match="batch program containing 2 circuits"): + _ = program.get_circuit() assert program.get_circuit(0) == circuit1 assert program.get_circuit(1) == circuit2 + assert program.get_circuits() == [circuit1, circuit2] + + with pytest.raises(IndexError): + _ = program.get_circuit(2) diff --git a/cirq-google/cirq_google/engine/abstract_program.py b/cirq-google/cirq_google/engine/abstract_program.py index d703c71a4fc..e0ed7519e2c 100644 --- a/cirq-google/cirq_google/engine/abstract_program.py +++ b/cirq-google/cirq_google/engine/abstract_program.py @@ -158,12 +158,12 @@ def remove_labels(self, keys: list[str]) -> AbstractProgram: """ @abc.abstractmethod - def get_circuit(self, program_num: int | None = None) -> cirq.Circuit: + def get_circuit(self, circuit_num: int | None = None) -> cirq.Circuit: """Returns the cirq Circuit for the program. This is only supported if the program was created with the V2 protos. Args: - program_num: if this is a multi-circuit program, the index of the circuit + circuit_num: if this is a multi-circuit program, the index of the circuit to return. This argument is zero-indexed. Negative values indexing from the end of the list. @@ -171,6 +171,18 @@ def get_circuit(self, program_num: int | None = None) -> cirq.Circuit: The program's cirq Circuit. """ + @abc.abstractmethod + def get_circuits(self) -> list[cirq.Circuit]: + """Returns all the cirq Circuits for the program. + + Returns: + A list of the program's cirq Circuits. + """ + + @abc.abstractmethod + def is_batch(self) -> bool: + """Returns True if the program is a batch program.""" + @abc.abstractmethod def batch_size(self) -> int: """Returns the number of programs in a batch program. diff --git a/cirq-google/cirq_google/engine/engine_job.py b/cirq-google/cirq_google/engine/engine_job.py index 1746789ee67..a2111e512e2 100644 --- a/cirq-google/cirq_google/engine/engine_job.py +++ b/cirq-google/cirq_google/engine/engine_job.py @@ -226,15 +226,52 @@ def failure(self) -> tuple[str, str] | None: return (failure.error_code.name, failure.error_message) return None - def get_repetitions_and_sweeps(self) -> tuple[int, list[cirq.Sweep]]: + def get_repetitions_and_sweeps( + self, circuit_num: int | None = None + ) -> tuple[int, list[cirq.Sweep]]: """Returns the repetitions and sweeps for the Quantum Engine job. + Args: + circuit_num: if this is a batch job, the index of the circuit + to return the sweeps for. This argument is zero-indexed. + Negative values index from the end of the list. + Returns: A tuple of the repetition count and list of sweeps. """ if self._job is None or self._job.run_context is None: self._job = self._get_job(return_run_context=True) - return _deserialize_run_context(self._job.run_context) + reps, sweeps = _deserialize_run_context(self._job.run_context) + + is_batch = self.program().is_batch() + batch_size = self.program().batch_size() if is_batch else 1 + + is_mapped = is_batch and len(sweeps) == batch_size and len(sweeps) > 1 + + if circuit_num is None: + if is_mapped: + raise ValueError( + f"This is a batch job with {len(sweeps)} mapped sweeps. " + "Please specify `circuit_num` to get sweeps for a specific circuit." + ) + return (reps, sweeps) + + if not is_batch: + if circuit_num != 0 and circuit_num != -1: + raise IndexError(f"Job is not a batch job, cannot index {circuit_num}") + return (reps, sweeps) + + if not is_mapped: + # Shared sweeps in a batch job, return all of them + return (reps, sweeps) + + # Mapped sweeps in a batch job + try: + return (reps, [sweeps[circuit_num]]) + except IndexError: + raise IndexError( + f"Index {circuit_num} out of range for batch job sweeps of size {len(sweeps)}." + ) def get_processor(self) -> engine_processor.EngineProcessor | None: """Returns the EngineProcessor for the processor the job is/was run on, @@ -258,18 +295,18 @@ def get_calibration(self) -> calibration.Calibration | None: metrics = v2.metrics_pb2.MetricsSnapshot.FromString(response.data.value) return calibration.Calibration(metrics) - async def get_circuit_async(self, program_num: int | None = None) -> cirq.Circuit: + async def get_circuit_async(self, circuit_num: int | None = None) -> cirq.Circuit: """Returns the cirq Circuit for the Quantum Engine job. Args: - program_num: if this is a multi-circuit job, the index of the circuit + circuit_num: if this is a multi-circuit job, the index of the circuit to return. This argument is zero-indexed. Negative values indexing from the end of the list. Returns: The job's cirq Circuit. """ - return await self.program().get_circuit_async(program_num) + return await self.program().get_circuit_async(circuit_num) get_circuit = duet.sync(get_circuit_async) diff --git a/cirq-google/cirq_google/engine/engine_job_test.py b/cirq-google/cirq_google/engine/engine_job_test.py index a75593566db..ecaa5939775 100644 --- a/cirq-google/cirq_google/engine/engine_job_test.py +++ b/cirq-google/cirq_google/engine/engine_job_test.py @@ -30,6 +30,110 @@ from cirq_google.engine.engine import EngineContext from cirq_google.engine.stream_manager import StreamError +_PROGRAM_V2 = util.pack_any( + Merge( + """language { + gate_set: "v2_5" + arg_function_language: "exp" +} +circuit { + scheduling_strategy: MOMENT_BY_MOMENT + moments { + operations { + qubit_constant_index: 0 + phasedxpowgate { + phase_exponent { + float_value: 0.0 + } + exponent { + float_value: 0.5 + } + } + } + } + moments { + operations { + qubit_constant_index: 0 + measurementgate { + key { + arg_value { + string_value: "result" + } + } + invert_mask { + arg_value { + bool_values { + } + } + } + } + } + } +} +constants { + qubit { + id: "5_2" + } +} +""", + v2.program_pb2.Program(), + ) +) + +_BATCH_PROGRAM_V2 = util.pack_any( + Merge( + """language { + gate_set: "v2_5" + arg_function_language: "exp" +} +keyed_circuits { + key: "c1" + circuit { + scheduling_strategy: MOMENT_BY_MOMENT + moments { + operations { + qubit_constant_index: 0 + phasedxpowgate { + phase_exponent { + float_value: 0.0 + } + exponent { + float_value: 0.5 + } + } + } + } + } +} +keyed_circuits { + key: "c2" + circuit { + scheduling_strategy: MOMENT_BY_MOMENT + moments { + operations { + qubit_constant_index: 0 + phasedxpowgate { + phase_exponent { + float_value: 0.0 + } + exponent { + float_value: 0.5 + } + } + } + } + } +} +constants { + qubit { + id: "5_2" + } +} +""", + v2.program_pb2.Program(), + ) +) + @pytest.fixture(scope='module', autouse=True) def mock_grpc_client(): @@ -220,9 +324,11 @@ def test_failure_with_no_error(): assert not job.failure() +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') -def test_get_repetitions_and_sweeps(get_job): - job = cg.EngineJob('a', 'b', 'steve', EngineContext()) +def test_get_repetitions_and_sweeps(get_job, get_program): + # Single program (non-batch) + get_program.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) get_job.return_value = quantum.QuantumJob( run_context=util.pack_any( v2.run_context_pb2.RunContext( @@ -230,8 +336,56 @@ def test_get_repetitions_and_sweeps(get_job): ) ) ) + job = cg.EngineJob('a', 'b', 'steve', EngineContext()) assert job.get_repetitions_and_sweeps() == (10, [cirq.UnitSweep]) - get_job.assert_called_once_with('a', 'b', 'steve', True) + assert job.get_repetitions_and_sweeps(0) == (10, [cirq.UnitSweep]) + with pytest.raises(IndexError, match="Job is not a batch job"): + _ = job.get_repetitions_and_sweeps(1) + + # Batch program, shared sweep + get_program.reset_mock() + get_program.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) + get_job.reset_mock() + get_job.return_value = quantum.QuantumJob( + run_context=util.pack_any( + v2.run_context_pb2.RunContext( + parameter_sweeps=[v2.run_context_pb2.ParameterSweep(repetitions=10)] + ) + ) + ) + job_batch_shared = cg.EngineJob('a', 'b', 'steve', EngineContext()) + assert job_batch_shared.get_repetitions_and_sweeps() == (10, [cirq.UnitSweep]) + assert job_batch_shared.get_repetitions_and_sweeps(0) == (10, [cirq.UnitSweep]) + assert job_batch_shared.get_repetitions_and_sweeps(1) == (10, [cirq.UnitSweep]) + + # Batch program, mapped sweeps + get_job.reset_mock() + get_job.return_value = quantum.QuantumJob( + run_context=util.pack_any( + v2.run_context_pb2.RunContext( + parameter_sweeps=[ + v2.run_context_pb2.ParameterSweep(repetitions=10), + v2.run_context_pb2.ParameterSweep( + repetitions=10, + sweep=v2.run_context_pb2.Sweep( + single_sweep=v2.run_context_pb2.SingleSweep( + parameter_key='t', + points=v2.run_context_pb2.Points(points_double=[1.0, 2.0]), + ) + ), + ), + ] + ) + ) + ) + job_batch_mapped = cg.EngineJob('a', 'b', 'steve', EngineContext()) + # Mapped sweeps, requires index + with pytest.raises(ValueError, match="mapped sweeps"): + _ = job_batch_mapped.get_repetitions_and_sweeps() + assert job_batch_mapped.get_repetitions_and_sweeps(0) == (10, [cirq.UnitSweep]) + assert job_batch_mapped.get_repetitions_and_sweeps(1) == (10, [cirq.Points('t', [1.0, 2.0])]) + with pytest.raises(IndexError, match="Index 2 out of range"): + _ = job_batch_mapped.get_repetitions_and_sweeps(2) @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') diff --git a/cirq-google/cirq_google/engine/engine_program.py b/cirq-google/cirq_google/engine/engine_program.py index 236b22a7d36..7f6b42cb22e 100644 --- a/cirq-google/cirq_google/engine/engine_program.py +++ b/cirq-google/cirq_google/engine/engine_program.py @@ -62,6 +62,7 @@ def __init__( self.program_id = program_id self.context = context self._program = _program + self._proto: v2.program_pb2.Program | None = None async def run_sweep_async( self, @@ -352,35 +353,87 @@ async def remove_labels_async(self, keys: list[str]) -> EngineProgram: remove_labels = duet.sync(remove_labels_async) - async def get_circuit_async(self, program_num: int | None = None) -> cirq.Circuit: + async def _get_proto_async(self) -> v2.program_pb2.Program: + if self._proto is None: + if self._program is None or not self._program.code or not self._program.code.type_url: + self._program = await self.context.client.get_program_async( + self.project_id, self.program_id, True + ) + self._proto = _deserialize_to_proto(self._program.code) + return self._proto + + async def is_batch_async(self) -> bool: + """Returns True if the program is a batch program.""" + proto = await self._get_proto_async() + return proto.WhichOneof('program') != 'circuit' + + is_batch = duet.sync(is_batch_async) + + async def get_circuit_async(self, circuit_num: int | None = None) -> cirq.Circuit: """Returns the cirq Circuit for the Quantum Engine program. This is only supported if the program was created with the V2 protos. Args: - program_num: if this is a multi-circuit program, the index of the circuit + circuit_num: if this is a multi-circuit program, the index of the circuit to return. This argument is zero-indexed. Negative values indexing from the end of the list. Returns: The program's cirq Circuit. """ - # The code field is an any_pb2.Any and is always set. But if the program has not - # been fetched this field may be empty, which we can see by checking the type_url. - if self._program is None or not self._program.code or not self._program.code.type_url: - self._program = await self.context.client.get_program_async( - self.project_id, self.program_id, True + proto = await self._get_proto_async() + is_batch = await self.is_batch_async() + + if circuit_num is None: + if is_batch: + raise ValueError( + f"Program {self.program_id} is a batch program containing " + f"{len(proto.keyed_circuits)} circuits. " + "Please specify `circuit_num` to get a specific circuit, " + "or use `get_circuits()` to get all of them." + ) + return circuit_serializer.CIRCUIT_SERIALIZER.deserialize(proto) + + if not is_batch: + if circuit_num != 0 and circuit_num != -1: + raise IndexError( + f"Program {self.program_id} is not a batch program, cannot index {circuit_num}" + ) + return circuit_serializer.CIRCUIT_SERIALIZER.deserialize(proto) + + deserialized = circuit_serializer.CIRCUIT_SERIALIZER.deserialize_multi_program(proto) + try: + return cast(cirq.Circuit, deserialized[circuit_num][2]) + except IndexError: + raise IndexError( + f"Index {circuit_num} out of range for batch program {self.program_id} " + f"of size {len(deserialized)}." ) - return _deserialize_program(self._program.code, program_num) get_circuit = duet.sync(get_circuit_async) + async def get_circuits_async(self) -> list[cirq.Circuit]: + """Returns all the cirq Circuits for the Quantum Engine program.""" + proto = await self._get_proto_async() + serializer = circuit_serializer.CIRCUIT_SERIALIZER + if not await self.is_batch_async(): + return [serializer.deserialize(proto)] + + deserialized = serializer.deserialize_multi_program(proto) + return [cast(cirq.Circuit, triple[2]) for triple in deserialized] + + get_circuits = duet.sync(get_circuits_async) + async def batch_size_async(self) -> int: """Returns the number of programs in a batch program. Raises: ValueError: if the program created was not a batch program. """ - raise NotImplementedError("Batch programs are no longer supported.") + proto = await self._get_proto_async() + if not await self.is_batch_async(): + raise ValueError(f"Program {self.program_id} is not a batch program.") + return len(proto.keyed_circuits) batch_size = duet.sync(batch_size_async) @@ -407,19 +460,24 @@ def __str__(self) -> str: return f'EngineProgram(project_id=\'{self.project_id}\', program_id=\'{self.program_id}\')' -def _deserialize_program(code: any_pb2.Any, program_num: int | None = None) -> cirq.Circuit: +def _deserialize_to_proto(code: any_pb2.Any) -> v2.program_pb2.Program: import cirq_google.engine.engine as engine_base code_type = code.type_url[len(engine_base.TYPE_PREFIX) :] - program = None if code_type == 'cirq.google.api.v1.Program' or code_type == 'cirq.api.google.v1.Program': raise ValueError('deserializing a v1 Program is not supported') elif code_type == 'cirq.google.api.v2.Program' or code_type == 'cirq.api.google.v2.Program': - program = v2.program_pb2.Program.FromString(code.value) - if program is not None: - serializer = circuit_serializer.CIRCUIT_SERIALIZER - if program_num is not None: - return cast(cirq.Circuit, serializer.deserialize_multi_program(program)[program_num][2]) - return serializer.deserialize(program) + return v2.program_pb2.Program.FromString(code.value) raise ValueError(f'unsupported program type: {code_type}') + + +def _deserialize_program(code: any_pb2.Any, circuit_num: int | None = None) -> cirq.Circuit: + program = _deserialize_to_proto(code) + serializer = circuit_serializer.CIRCUIT_SERIALIZER + if circuit_num is not None: + try: + return cast(cirq.Circuit, serializer.deserialize_multi_program(program)[circuit_num][2]) + except IndexError: + raise IndexError(f"Index {circuit_num} out of range for batch program.") + return serializer.deserialize(program) diff --git a/cirq-google/cirq_google/engine/engine_program_test.py b/cirq-google/cirq_google/engine/engine_program_test.py index 885ffa3228f..09ebaa5b4e3 100644 --- a/cirq-google/cirq_google/engine/engine_program_test.py +++ b/cirq-google/cirq_google/engine/engine_program_test.py @@ -81,6 +81,61 @@ ) +_BATCH_PROGRAM_V2 = util.pack_any( + Merge( + """language { + gate_set: "v2_5" + arg_function_language: "exp" +} +keyed_circuits { + key: "c1" + circuit { + scheduling_strategy: MOMENT_BY_MOMENT + moments { + operations { + qubit_constant_index: 0 + phasedxpowgate { + phase_exponent { + float_value: 0.0 + } + exponent { + float_value: 0.5 + } + } + } + } + } +} +keyed_circuits { + key: "c2" + circuit { + scheduling_strategy: MOMENT_BY_MOMENT + moments { + operations { + qubit_constant_index: 0 + phasedxpowgate { + phase_exponent { + float_value: 0.0 + } + exponent { + float_value: 0.5 + } + } + } + } + } +} +constants { + qubit { + id: "5_2" + } +} +""", + v2.program_pb2.Program(), + ) +) + + @mock.patch('cirq_google.engine.engine_client.EngineClient.create_job_async') def test_run_sweeps_delegation(create_job_async): create_job_async.return_value = ('steve', quantum.QuantumJob()) @@ -315,29 +370,40 @@ def test_get_circuit_v2(get_program_async, include_empty_program: bool) -> None: ) get_program_async.assert_called_once_with('a', 'b', True) + # Test indexing on a batch program + program_batch = cg.EngineProgram('a', 'b', EngineContext()) + get_program_async.reset_mock() + get_program_async.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) with mock.patch( 'cirq_google.serialization.circuit_serializer.CIRCUIT_SERIALIZER.deserialize_multi_program' ) as deserialize_multi_program: deserialize_multi_program.return_value = [('key0', (), circuit), ('key1', (), circuit)] - assert program.get_circuit(program_num=1) is circuit + assert program_batch.get_circuit(circuit_num=1) is circuit deserialize_multi_program.assert_called_once() @duet.sync async def test_get_circuit_async(): context = EngineContext() + circuit = cirq.Circuit(cirq.X(cirq.GridQubit(5, 2)) ** 0.5) + + # Single circuit program = cg.EngineProgram('a', 'b', context) - circuit = cirq.Circuit() with mock.patch.object( context.client, 'get_program_async', new_callable=mock.AsyncMock ) as get_program_async: get_program_async.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) - with mock.patch( - 'cirq_google.engine.engine_program._deserialize_program' - ) as deserialize_program: - deserialize_program.return_value = circuit - assert await program.get_circuit_async(1) == circuit - deserialize_program.assert_called_once_with(mock.ANY, 1) + c = await program.get_circuit_async() + assert isinstance(c, cirq.Circuit) + + # Batch circuit + program_batch = cg.EngineProgram('a', 'b', context) + with mock.patch.object( + context.client, 'get_program_async', new_callable=mock.AsyncMock + ) as get_program_async: + get_program_async.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) + c = await program_batch.get_circuit_async(1) + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(c, circuit) def test_deserialize_program(): @@ -347,10 +413,16 @@ def test_deserialize_program(): ) as mock_serializer: cg.engine.engine_program._deserialize_program(code) mock_serializer.deserialize.assert_called_once() - cg.engine.engine_program._deserialize_program(code, program_num=1) + cg.engine.engine_program._deserialize_program(code, circuit_num=1) mock_serializer.deserialize_multi_program.assert_called_once() +def test_deserialize_program_errors(): + # Index out of range + with pytest.raises(IndexError, match="Index 2 out of range"): + cg.engine.engine_program._deserialize_program(_BATCH_PROGRAM_V2, circuit_num=2) + + @pytest.fixture(scope='module', autouse=True) def mock_grpc_client(): with mock.patch( @@ -390,3 +462,83 @@ def test_delete_jobs(delete_job_async): def test_str(): program = cg.EngineProgram('my-proj', 'my-prog', EngineContext()) assert str(program) == 'EngineProgram(project_id=\'my-proj\', program_id=\'my-prog\')' + + +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_batch_size(get_program_async): + # Single circuit program (not a batch) + program = cg.EngineProgram('a', 'b', EngineContext()) + get_program_async.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) + with pytest.raises(ValueError, match="not a batch program"): + _ = program.batch_size() + + # Batch program + program_batch = cg.EngineProgram('a', 'b', EngineContext()) + get_program_async.reset_mock() + get_program_async.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) + assert program_batch.batch_size() == 2 + get_program_async.assert_called_once_with('a', 'b', True) + + +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_get_circuits(get_program_async): + circuit = cirq.Circuit( + cirq.X(cirq.GridQubit(5, 2)) ** 0.5, cirq.measure(cirq.GridQubit(5, 2), key='result') + ) + + # Single circuit program + program = cg.EngineProgram('a', 'b', EngineContext()) + get_program_async.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) + circuits = program.get_circuits() + assert len(circuits) == 1 + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(circuits[0], circuit) + + # Batch program + program_batch = cg.EngineProgram('a', 'b', EngineContext()) + get_program_async.reset_mock() + get_program_async.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) + circuits = program_batch.get_circuits() + assert len(circuits) == 2 + expected_circuit = cirq.Circuit(cirq.X(cirq.GridQubit(5, 2)) ** 0.5) + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( + circuits[0], expected_circuit + ) + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( + circuits[1], expected_circuit + ) + + +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_get_circuit_error_cases(get_program_async): + # Batch program, no index passed + program_batch = cg.EngineProgram('a', 'b', EngineContext()) + get_program_async.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) + with pytest.raises(ValueError, match="is a batch program containing 2 circuits"): + _ = program_batch.get_circuit() + + # Batch program, index out of range + with pytest.raises(IndexError, match="Index 2 out of range"): + _ = program_batch.get_circuit(2) + + # Single circuit program, indexing not allowed (except 0 and -1) + program_single = cg.EngineProgram('a', 'b', EngineContext()) + get_program_async.reset_mock() + get_program_async.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) + assert program_single.get_circuit(0) is not None + assert program_single.get_circuit(-1) is not None + with pytest.raises(IndexError, match="is not a batch program, cannot index 1"): + _ = program_single.get_circuit(1) + + +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_is_batch(get_program_async): + # Single circuit program + program = cg.EngineProgram('a', 'b', EngineContext()) + get_program_async.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) + assert not program.is_batch() + + # Batch program + program_batch = cg.EngineProgram('a', 'b', EngineContext()) + get_program_async.reset_mock() + get_program_async.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) + assert program_batch.is_batch() diff --git a/cirq-google/cirq_google/engine/simulated_local_job.py b/cirq-google/cirq_google/engine/simulated_local_job.py index 66a2bddb420..b483d1297d3 100644 --- a/cirq-google/cirq_google/engine/simulated_local_job.py +++ b/cirq-google/cirq_google/engine/simulated_local_job.py @@ -122,12 +122,12 @@ def _execute_results(self) -> Sequence[Sequence[EngineResult]]: Returns: a List of results from the sweep's execution. """ - reps, sweeps = self.get_repetitions_and_sweeps() + reps = self._repetitions + sweeps = self._sweeps parent = self.program() - batch_size = parent.batch_size() try: self._state = quantum.ExecutionStatus.State.RUNNING - programs = [parent.get_circuit(n) for n in range(batch_size)] + programs = parent.get_circuits() if len(sweeps) == 1 and len(programs) > 1: sweeps = sweeps * len(programs) batch_results = self._sampler.run_batch(