From 7133c6ef03eaa02114fef05a813d8565500bd1b0 Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Thu, 14 May 2026 09:34:03 -0700 Subject: [PATCH 1/4] Fix issues with multi-circuit size in Engine - Fix Program.get_circuit to give more descriptive error messages. - Fix batch_size() to work with multiple circuits. - Fix get_repetitions_and_sweeps to work with multi-circuits. - Add get_circuits() to get all the circuits. --- .../cirq_google/engine/abstract_job.py | 11 +- .../cirq_google/engine/abstract_job_test.py | 2 +- .../cirq_google/engine/abstract_local_job.py | 43 ++++- .../engine/abstract_local_job_test.py | 41 ++++- .../engine/abstract_local_program.py | 43 ++++- .../engine/abstract_local_program_test.py | 17 +- .../cirq_google/engine/abstract_program.py | 16 +- cirq-google/cirq_google/engine/engine_job.py | 45 ++++- .../cirq_google/engine/engine_job_test.py | 161 ++++++++++++++++- .../cirq_google/engine/engine_program.py | 93 ++++++++-- .../cirq_google/engine/engine_program_test.py | 168 +++++++++++++++++- .../cirq_google/engine/simulated_local_job.py | 6 +- 12 files changed, 588 insertions(+), 58 deletions(-) diff --git a/cirq-google/cirq_google/engine/abstract_job.py b/cirq-google/cirq_google/engine/abstract_job.py index dc4e3a8d193..94ad91fd753 100644 --- a/cirq-google/cirq_google/engine/abstract_job.py +++ b/cirq-google/cirq_google/engine/abstract_job.py @@ -141,9 +141,14 @@ def failure(self) -> tuple[str, str] | None: """Return failure code and message of the job if present.""" @abc.abstractmethod - def get_repetitions_and_sweeps(self) -> tuple[int, list[cirq.Sweep]]: + def get_repetitions_and_sweeps(self, circuit_num: int | None = None) -> tuple[int, list[cirq.Sweep]]: """Returns the repetitions and sweeps for the job. + Args: + circuit_num: if this is a batch job, the index of the circuit + to return the sweeps for. This argument is zero-indexed. + Negative values index from the end of the list. + Returns: A tuple of the repetition count and list of sweeps. """ @@ -159,11 +164,11 @@ def get_calibration(self) -> calibration.Calibration | None: one was captured, else None.""" @abc.abstractmethod - def get_circuit(self, program_num: int | None = None) -> cirq.Circuit: + def get_circuit(self, circuit_num: int | None = None) -> cirq.Circuit: """Returns the cirq Circuit for the job. Args: - program_num: if this is a multi-circuit job, the index of the circuit + circuit_num: if this is a multi-circuit job, the index of the circuit to return. This argument is zero-indexed. Negative values index from the end of the list. Ignored if not multi-circuit. diff --git a/cirq-google/cirq_google/engine/abstract_job_test.py b/cirq-google/cirq_google/engine/abstract_job_test.py index 179aa86df91..822f1f61415 100644 --- a/cirq-google/cirq_google/engine/abstract_job_test.py +++ b/cirq-google/cirq_google/engine/abstract_job_test.py @@ -80,7 +80,7 @@ def get_processor(self): def get_calibration(self): pass - def get_circuit(self, program_num: int | None = None) -> cirq.Circuit: + def get_circuit(self, circuit_num: int | None = None) -> cirq.Circuit: return cirq.Circuit() def cancel(self) -> None: diff --git a/cirq-google/cirq_google/engine/abstract_local_job.py b/cirq-google/cirq_google/engine/abstract_local_job.py index d3fc5b0b8db..620bd8289bc 100644 --- a/cirq-google/cirq_google/engine/abstract_local_job.py +++ b/cirq-google/cirq_google/engine/abstract_local_job.py @@ -157,13 +157,46 @@ def processor_ids(self) -> list[str]: """Returns the processor ids provided when the job was created.""" return [self._processor_id] - def get_repetitions_and_sweeps(self) -> tuple[int, list[cirq.Sweep]]: + def get_repetitions_and_sweeps(self, circuit_num: int | None = None) -> tuple[int, list[cirq.Sweep]]: """Returns the repetitions and sweeps for the job. + Args: + circuit_num: if this is a batch job, the index of the circuit + to return the sweeps for. This argument is zero-indexed. + Negative values index from the end of the list. + Returns: A tuple of the repetition count and list of sweeps. """ - return (self._repetitions, self._sweeps) + is_batch = self.program().is_batch() + batch_size = self.program().batch_size() if is_batch else 1 + + is_mapped = is_batch and len(self._sweeps) == batch_size and len(self._sweeps) > 1 + + if circuit_num is None: + if is_mapped: + raise ValueError( + f"This is a batch job with {len(self._sweeps)} mapped sweeps. " + "Please specify `circuit_num` to get sweeps for a specific circuit." + ) + return (self._repetitions, self._sweeps) + + if not is_batch: + if circuit_num != 0 and circuit_num != -1: + raise IndexError(f"Job is not a batch job, cannot index {circuit_num}") + return (self._repetitions, self._sweeps) + + if not is_mapped: + # Shared sweeps in a batch job, return all of them + return (self._repetitions, self._sweeps) + + # Mapped sweeps in a batch job + try: + return (self._repetitions, [self._sweeps[circuit_num]]) + except IndexError: + raise IndexError( + f"Index {circuit_num} out of range for batch job sweeps of size {len(self._sweeps)}." + ) def get_processor(self) -> AbstractLocalProcessor: """Returns the AbstractProcessor for the processor the job is/was run on, @@ -175,15 +208,15 @@ def get_calibration(self) -> calibration.Calibration | None: from the parent Engine object.""" return self.get_processor().get_latest_calibration(int(self._create_time.timestamp())) - def get_circuit(self, program_num: int | None = None) -> cirq.Circuit: + def get_circuit(self, circuit_num: int | None = None) -> cirq.Circuit: """Returns the cirq Circuit for the job. Args: - program_num: if this is a multi-circuit job, the index of the circuit + circuit_num: if this is a multi-circuit job, the index of the circuit to return. This argument is zero-indexed. Negative values index from the end of the list. Ignored if not multi-circuit. Returns: The job's cirq Circuit. """ - return self.program().get_circuit(program_num) + return self.program().get_circuit(circuit_num) diff --git a/cirq-google/cirq_google/engine/abstract_local_job_test.py b/cirq-google/cirq_google/engine/abstract_local_job_test.py index 38fccc69971..e77f483de88 100644 --- a/cirq-google/cirq_google/engine/abstract_local_job_test.py +++ b/cirq-google/cirq_google/engine/abstract_local_job_test.py @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING from unittest import mock +import pytest import cirq from cirq_google.cloud import quantum from cirq_google.engine.abstract_local_job import AbstractLocalJob @@ -75,14 +76,52 @@ def test_description_and_labels(): def test_reps_and_sweeps(): + # Single program (non-batch) + mock_program = mock.Mock() + mock_program.is_batch.return_value = False job = NothingJob( job_id='test', processor_id='grill', - parent_program=None, + parent_program=mock_program, repetitions=100, sweeps=[cirq.Linspace('t', 0, 10, 0.1)], ) assert job.get_repetitions_and_sweeps() == (100, [cirq.Linspace('t', 0, 10, 0.1)]) + assert job.get_repetitions_and_sweeps(0) == (100, [cirq.Linspace('t', 0, 10, 0.1)]) + with pytest.raises(IndexError, match="Job is not a batch job"): + _ = job.get_repetitions_and_sweeps(1) + + # Batch program, shared sweep + mock_program_batch = mock.Mock() + mock_program_batch.is_batch.return_value = True + mock_program_batch.batch_size.return_value = 2 + job_batch_shared = NothingJob( + job_id='test', + processor_id='grill', + parent_program=mock_program_batch, + repetitions=100, + sweeps=[cirq.Linspace('t', 0, 10, 0.1)], + ) + # Shared sweep, works with None or any index + assert job_batch_shared.get_repetitions_and_sweeps() == (100, [cirq.Linspace('t', 0, 10, 0.1)]) + assert job_batch_shared.get_repetitions_and_sweeps(0) == (100, [cirq.Linspace('t', 0, 10, 0.1)]) + assert job_batch_shared.get_repetitions_and_sweeps(1) == (100, [cirq.Linspace('t', 0, 10, 0.1)]) + + # Batch program, mapped sweeps + job_batch_mapped = NothingJob( + job_id='test', + processor_id='grill', + parent_program=mock_program_batch, + repetitions=100, + sweeps=[cirq.Linspace('t', 0, 10, 0.1), cirq.Linspace('u', 0, 5, 0.5)], + ) + # Mapped sweeps, requires index + with pytest.raises(ValueError, match="mapped sweeps"): + _ = job_batch_mapped.get_repetitions_and_sweeps() + assert job_batch_mapped.get_repetitions_and_sweeps(0) == (100, [cirq.Linspace('t', 0, 10, 0.1)]) + assert job_batch_mapped.get_repetitions_and_sweeps(1) == (100, [cirq.Linspace('u', 0, 5, 0.5)]) + with pytest.raises(IndexError, match="Index 2 out of range"): + _ = job_batch_mapped.get_repetitions_and_sweeps(2) def test_create_update_time(): diff --git a/cirq-google/cirq_google/engine/abstract_local_program.py b/cirq-google/cirq_google/engine/abstract_local_program.py index 685a49fb45e..5fea814729f 100644 --- a/cirq-google/cirq_google/engine/abstract_local_program.py +++ b/cirq-google/cirq_google/engine/abstract_local_program.py @@ -39,7 +39,11 @@ class AbstractLocalProgram(AbstractProgram): need to implement abstract methods. """ - def __init__(self, circuits: list[cirq.Circuit], engine: AbstractLocalEngine): + def __init__( + self, + circuits: list[cirq.Circuit], + engine: AbstractLocalEngine, + ): if not circuits: raise ValueError('No circuits provided to program.') self._create_time = datetime.datetime.now() @@ -190,22 +194,47 @@ def remove_labels(self, keys: list[str]) -> AbstractProgram: del self._labels[key] return self - def get_circuit(self, program_num: int | None = None) -> cirq.Circuit: + def get_circuit(self, circuit_num: int | None = None) -> cirq.Circuit: """Returns the cirq Circuit for the program. This is only supported if the program was created with the V2 protos. Args: - program_num: if this is a multi-circuit program, the index of the circuit + circuit_num: if this is a multi-circuit program, the index of the circuit to return. This argument is zero-indexed. Negative values indexing from the end of the list. Returns: The program's cirq Circuit. """ - if program_num is not None: - return self._circuits[program_num] - return self._circuits[0] + if circuit_num is None: + if self.is_batch(): + raise ValueError( + f"This program is a batch program containing {len(self._circuits)} circuits. " + "Please specify `circuit_num` to get a specific circuit, " + "or use `get_circuits()` to get all of them." + ) + return self._circuits[0] + try: + return self._circuits[circuit_num] + except IndexError: + raise IndexError( + f"Index {circuit_num} out of range for batch program of size {len(self._circuits)}." + ) + + def get_circuits(self) -> list[cirq.Circuit]: + """Returns all the cirq Circuits for the program.""" + return self._circuits + + def is_batch(self) -> bool: + """Returns True if the program is a batch program.""" + return len(self._circuits) > 1 def batch_size(self) -> int: - """Returns the number of programs in a batch program.""" + """Returns the number of programs in a batch program. + + Raises: + ValueError: if the program created was not a batch program. + """ + if not self.is_batch(): + raise ValueError("This program is not a batch program.") return len(self._circuits) diff --git a/cirq-google/cirq_google/engine/abstract_local_program_test.py b/cirq-google/cirq_google/engine/abstract_local_program_test.py index d658618ab6b..19a1a288cc0 100644 --- a/cirq-google/cirq_google/engine/abstract_local_program_test.py +++ b/cirq-google/cirq_google/engine/abstract_local_program_test.py @@ -149,12 +149,25 @@ def test_description_and_labels(): def test_circuit(): circuit1 = cirq.Circuit(cirq.X(cirq.LineQubit(1))) circuit2 = cirq.Circuit(cirq.Y(cirq.LineQubit(2))) + + # Single circuit, non-batch program = NothingProgram([circuit1], None) - assert program.batch_size() == 1 + assert not program.is_batch() + with pytest.raises(ValueError, match="not a batch program"): + _ = program.batch_size() assert program.get_circuit() == circuit1 assert program.get_circuit(0) == circuit1 - assert program.batch_size() == 1 + assert program.get_circuits() == [circuit1] + + # Multi circuit (always batch) program = NothingProgram([circuit1, circuit2], None) + assert program.is_batch() assert program.batch_size() == 2 + with pytest.raises(ValueError, match="batch program containing 2 circuits"): + _ = program.get_circuit() assert program.get_circuit(0) == circuit1 assert program.get_circuit(1) == circuit2 + assert program.get_circuits() == [circuit1, circuit2] + + with pytest.raises(IndexError): + _ = program.get_circuit(2) diff --git a/cirq-google/cirq_google/engine/abstract_program.py b/cirq-google/cirq_google/engine/abstract_program.py index d703c71a4fc..e0ed7519e2c 100644 --- a/cirq-google/cirq_google/engine/abstract_program.py +++ b/cirq-google/cirq_google/engine/abstract_program.py @@ -158,12 +158,12 @@ def remove_labels(self, keys: list[str]) -> AbstractProgram: """ @abc.abstractmethod - def get_circuit(self, program_num: int | None = None) -> cirq.Circuit: + def get_circuit(self, circuit_num: int | None = None) -> cirq.Circuit: """Returns the cirq Circuit for the program. This is only supported if the program was created with the V2 protos. Args: - program_num: if this is a multi-circuit program, the index of the circuit + circuit_num: if this is a multi-circuit program, the index of the circuit to return. This argument is zero-indexed. Negative values indexing from the end of the list. @@ -171,6 +171,18 @@ def get_circuit(self, program_num: int | None = None) -> cirq.Circuit: The program's cirq Circuit. """ + @abc.abstractmethod + def get_circuits(self) -> list[cirq.Circuit]: + """Returns all the cirq Circuits for the program. + + Returns: + A list of the program's cirq Circuits. + """ + + @abc.abstractmethod + def is_batch(self) -> bool: + """Returns True if the program is a batch program.""" + @abc.abstractmethod def batch_size(self) -> int: """Returns the number of programs in a batch program. diff --git a/cirq-google/cirq_google/engine/engine_job.py b/cirq-google/cirq_google/engine/engine_job.py index 1746789ee67..2d532845487 100644 --- a/cirq-google/cirq_google/engine/engine_job.py +++ b/cirq-google/cirq_google/engine/engine_job.py @@ -226,15 +226,50 @@ def failure(self) -> tuple[str, str] | None: return (failure.error_code.name, failure.error_message) return None - def get_repetitions_and_sweeps(self) -> tuple[int, list[cirq.Sweep]]: + def get_repetitions_and_sweeps(self, circuit_num: int | None = None) -> tuple[int, list[cirq.Sweep]]: """Returns the repetitions and sweeps for the Quantum Engine job. + Args: + circuit_num: if this is a batch job, the index of the circuit + to return the sweeps for. This argument is zero-indexed. + Negative values index from the end of the list. + Returns: A tuple of the repetition count and list of sweeps. """ if self._job is None or self._job.run_context is None: self._job = self._get_job(return_run_context=True) - return _deserialize_run_context(self._job.run_context) + reps, sweeps = _deserialize_run_context(self._job.run_context) + + is_batch = self.program().is_batch() + batch_size = self.program().batch_size() if is_batch else 1 + + is_mapped = is_batch and len(sweeps) == batch_size and len(sweeps) > 1 + + if circuit_num is None: + if is_mapped: + raise ValueError( + f"This is a batch job with {len(sweeps)} mapped sweeps. " + "Please specify `circuit_num` to get sweeps for a specific circuit." + ) + return (reps, sweeps) + + if not is_batch: + if circuit_num != 0 and circuit_num != -1: + raise IndexError(f"Job is not a batch job, cannot index {circuit_num}") + return (reps, sweeps) + + if not is_mapped: + # Shared sweeps in a batch job, return all of them + return (reps, sweeps) + + # Mapped sweeps in a batch job + try: + return (reps, [sweeps[circuit_num]]) + except IndexError: + raise IndexError( + f"Index {circuit_num} out of range for batch job sweeps of size {len(sweeps)}." + ) def get_processor(self) -> engine_processor.EngineProcessor | None: """Returns the EngineProcessor for the processor the job is/was run on, @@ -258,18 +293,18 @@ def get_calibration(self) -> calibration.Calibration | None: metrics = v2.metrics_pb2.MetricsSnapshot.FromString(response.data.value) return calibration.Calibration(metrics) - async def get_circuit_async(self, program_num: int | None = None) -> cirq.Circuit: + async def get_circuit_async(self, circuit_num: int | None = None) -> cirq.Circuit: """Returns the cirq Circuit for the Quantum Engine job. Args: - program_num: if this is a multi-circuit job, the index of the circuit + circuit_num: if this is a multi-circuit job, the index of the circuit to return. This argument is zero-indexed. Negative values indexing from the end of the list. Returns: The job's cirq Circuit. """ - return await self.program().get_circuit_async(program_num) + return await self.program().get_circuit_async(circuit_num) get_circuit = duet.sync(get_circuit_async) diff --git a/cirq-google/cirq_google/engine/engine_job_test.py b/cirq-google/cirq_google/engine/engine_job_test.py index a75593566db..0cc8cf594ef 100644 --- a/cirq-google/cirq_google/engine/engine_job_test.py +++ b/cirq-google/cirq_google/engine/engine_job_test.py @@ -31,6 +31,111 @@ from cirq_google.engine.stream_manager import StreamError +_PROGRAM_V2 = util.pack_any( + Merge( + """language { + gate_set: "v2_5" + arg_function_language: "exp" +} +circuit { + scheduling_strategy: MOMENT_BY_MOMENT + moments { + operations { + qubit_constant_index: 0 + phasedxpowgate { + phase_exponent { + float_value: 0.0 + } + exponent { + float_value: 0.5 + } + } + } + } + moments { + operations { + qubit_constant_index: 0 + measurementgate { + key { + arg_value { + string_value: "result" + } + } + invert_mask { + arg_value { + bool_values { + } + } + } + } + } + } +} +constants { + qubit { + id: "5_2" + } +} +""", + v2.program_pb2.Program(), + ) +) + +_BATCH_PROGRAM_V2 = util.pack_any( + Merge( + """language { + gate_set: "v2_5" + arg_function_language: "exp" +} +keyed_circuits { + key: "c1" + circuit { + scheduling_strategy: MOMENT_BY_MOMENT + moments { + operations { + qubit_constant_index: 0 + phasedxpowgate { + phase_exponent { + float_value: 0.0 + } + exponent { + float_value: 0.5 + } + } + } + } + } +} +keyed_circuits { + key: "c2" + circuit { + scheduling_strategy: MOMENT_BY_MOMENT + moments { + operations { + qubit_constant_index: 0 + phasedxpowgate { + phase_exponent { + float_value: 0.0 + } + exponent { + float_value: 0.5 + } + } + } + } + } +} +constants { + qubit { + id: "5_2" + } +} +""", + v2.program_pb2.Program(), + ) +) + + @pytest.fixture(scope='module', autouse=True) def mock_grpc_client(): with mock.patch( @@ -220,9 +325,11 @@ def test_failure_with_no_error(): assert not job.failure() +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') -def test_get_repetitions_and_sweeps(get_job): - job = cg.EngineJob('a', 'b', 'steve', EngineContext()) +def test_get_repetitions_and_sweeps(get_job, get_program): + # Single program (non-batch) + get_program.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) get_job.return_value = quantum.QuantumJob( run_context=util.pack_any( v2.run_context_pb2.RunContext( @@ -230,8 +337,56 @@ def test_get_repetitions_and_sweeps(get_job): ) ) ) + job = cg.EngineJob('a', 'b', 'steve', EngineContext()) assert job.get_repetitions_and_sweeps() == (10, [cirq.UnitSweep]) - get_job.assert_called_once_with('a', 'b', 'steve', True) + assert job.get_repetitions_and_sweeps(0) == (10, [cirq.UnitSweep]) + with pytest.raises(IndexError, match="Job is not a batch job"): + _ = job.get_repetitions_and_sweeps(1) + + # Batch program, shared sweep + get_program.reset_mock() + get_program.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) + get_job.reset_mock() + get_job.return_value = quantum.QuantumJob( + run_context=util.pack_any( + v2.run_context_pb2.RunContext( + parameter_sweeps=[v2.run_context_pb2.ParameterSweep(repetitions=10)] + ) + ) + ) + job_batch_shared = cg.EngineJob('a', 'b', 'steve', EngineContext()) + assert job_batch_shared.get_repetitions_and_sweeps() == (10, [cirq.UnitSweep]) + assert job_batch_shared.get_repetitions_and_sweeps(0) == (10, [cirq.UnitSweep]) + assert job_batch_shared.get_repetitions_and_sweeps(1) == (10, [cirq.UnitSweep]) + + # Batch program, mapped sweeps + get_job.reset_mock() + get_job.return_value = quantum.QuantumJob( + run_context=util.pack_any( + v2.run_context_pb2.RunContext( + parameter_sweeps=[ + v2.run_context_pb2.ParameterSweep(repetitions=10), + v2.run_context_pb2.ParameterSweep( + repetitions=10, + sweep=v2.run_context_pb2.Sweep( + single_sweep=v2.run_context_pb2.SingleSweep( + parameter_key='t', + points=v2.run_context_pb2.Points(points_double=[1.0, 2.0]) + ) + ) + ) + ] + ) + ) + ) + job_batch_mapped = cg.EngineJob('a', 'b', 'steve', EngineContext()) + # Mapped sweeps, requires index + with pytest.raises(ValueError, match="mapped sweeps"): + _ = job_batch_mapped.get_repetitions_and_sweeps() + assert job_batch_mapped.get_repetitions_and_sweeps(0) == (10, [cirq.UnitSweep]) + assert job_batch_mapped.get_repetitions_and_sweeps(1) == (10, [cirq.Points('t', [1.0, 2.0])]) + with pytest.raises(IndexError, match="Index 2 out of range"): + _ = job_batch_mapped.get_repetitions_and_sweeps(2) @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') diff --git a/cirq-google/cirq_google/engine/engine_program.py b/cirq-google/cirq_google/engine/engine_program.py index 236b22a7d36..27eac1afd01 100644 --- a/cirq-google/cirq_google/engine/engine_program.py +++ b/cirq-google/cirq_google/engine/engine_program.py @@ -352,35 +352,87 @@ async def remove_labels_async(self, keys: list[str]) -> EngineProgram: remove_labels = duet.sync(remove_labels_async) - async def get_circuit_async(self, program_num: int | None = None) -> cirq.Circuit: + async def _get_proto_async(self) -> v2.program_pb2.Program: + if not hasattr(self, '_proto') or self._proto is None: + if self._program is None or not self._program.code or not self._program.code.type_url: + self._program = await self.context.client.get_program_async( + self.project_id, self.program_id, True + ) + self._proto = _deserialize_to_proto(self._program.code) + return self._proto + + async def is_batch_async(self) -> bool: + """Returns True if the program is a batch program.""" + proto = await self._get_proto_async() + return proto.WhichOneof('program') != 'circuit' + + is_batch = duet.sync(is_batch_async) + + async def get_circuit_async(self, circuit_num: int | None = None) -> cirq.Circuit: """Returns the cirq Circuit for the Quantum Engine program. This is only supported if the program was created with the V2 protos. Args: - program_num: if this is a multi-circuit program, the index of the circuit + circuit_num: if this is a multi-circuit program, the index of the circuit to return. This argument is zero-indexed. Negative values indexing from the end of the list. Returns: The program's cirq Circuit. """ - # The code field is an any_pb2.Any and is always set. But if the program has not - # been fetched this field may be empty, which we can see by checking the type_url. - if self._program is None or not self._program.code or not self._program.code.type_url: - self._program = await self.context.client.get_program_async( - self.project_id, self.program_id, True + proto = await self._get_proto_async() + is_batch = await self.is_batch_async() + + if circuit_num is None: + if is_batch: + raise ValueError( + f"Program {self.program_id} is a batch program containing " + f"{len(proto.keyed_circuits)} circuits. " + "Please specify `circuit_num` to get a specific circuit, " + "or use `get_circuits()` to get all of them." + ) + return circuit_serializer.CIRCUIT_SERIALIZER.deserialize(proto) + + if not is_batch: + if circuit_num != 0 and circuit_num != -1: + raise IndexError( + f"Program {self.program_id} is not a batch program, cannot index {circuit_num}" + ) + return circuit_serializer.CIRCUIT_SERIALIZER.deserialize(proto) + + deserialized = circuit_serializer.CIRCUIT_SERIALIZER.deserialize_multi_program(proto) + try: + return cast(cirq.Circuit, deserialized[circuit_num][2]) + except IndexError: + raise IndexError( + f"Index {circuit_num} out of range for batch program {self.program_id} " + f"of size {len(deserialized)}." ) - return _deserialize_program(self._program.code, program_num) get_circuit = duet.sync(get_circuit_async) + async def get_circuits_async(self) -> list[cirq.Circuit]: + """Returns all the cirq Circuits for the Quantum Engine program.""" + proto = await self._get_proto_async() + serializer = circuit_serializer.CIRCUIT_SERIALIZER + if not await self.is_batch_async(): + return [serializer.deserialize(proto)] + + deserialized = serializer.deserialize_multi_program(proto) + return [triple[2] for triple in deserialized] + + get_circuits = duet.sync(get_circuits_async) + async def batch_size_async(self) -> int: """Returns the number of programs in a batch program. Raises: ValueError: if the program created was not a batch program. """ - raise NotImplementedError("Batch programs are no longer supported.") + proto = await self._get_proto_async() + if not await self.is_batch_async(): + raise ValueError(f"Program {self.program_id} is not a batch program.") + return len(proto.keyed_circuits) batch_size = duet.sync(batch_size_async) @@ -407,19 +459,26 @@ def __str__(self) -> str: return f'EngineProgram(project_id=\'{self.project_id}\', program_id=\'{self.program_id}\')' -def _deserialize_program(code: any_pb2.Any, program_num: int | None = None) -> cirq.Circuit: +def _deserialize_to_proto(code: any_pb2.Any) -> v2.program_pb2.Program: import cirq_google.engine.engine as engine_base code_type = code.type_url[len(engine_base.TYPE_PREFIX) :] - program = None if code_type == 'cirq.google.api.v1.Program' or code_type == 'cirq.api.google.v1.Program': raise ValueError('deserializing a v1 Program is not supported') elif code_type == 'cirq.google.api.v2.Program' or code_type == 'cirq.api.google.v2.Program': - program = v2.program_pb2.Program.FromString(code.value) - if program is not None: - serializer = circuit_serializer.CIRCUIT_SERIALIZER - if program_num is not None: - return cast(cirq.Circuit, serializer.deserialize_multi_program(program)[program_num][2]) - return serializer.deserialize(program) + return v2.program_pb2.Program.FromString(code.value) raise ValueError(f'unsupported program type: {code_type}') + + +def _deserialize_program(code: any_pb2.Any, circuit_num: int | None = None) -> cirq.Circuit: + program = _deserialize_to_proto(code) + serializer = circuit_serializer.CIRCUIT_SERIALIZER + if circuit_num is not None: + try: + return cast(cirq.Circuit, serializer.deserialize_multi_program(program)[circuit_num][2]) + except IndexError: + raise IndexError( + f"Index {circuit_num} out of range for batch program." + ) + return serializer.deserialize(program) diff --git a/cirq-google/cirq_google/engine/engine_program_test.py b/cirq-google/cirq_google/engine/engine_program_test.py index 885ffa3228f..135b4d1e53d 100644 --- a/cirq-google/cirq_google/engine/engine_program_test.py +++ b/cirq-google/cirq_google/engine/engine_program_test.py @@ -81,6 +81,61 @@ ) +_BATCH_PROGRAM_V2 = util.pack_any( + Merge( + """language { + gate_set: "v2_5" + arg_function_language: "exp" +} +keyed_circuits { + key: "c1" + circuit { + scheduling_strategy: MOMENT_BY_MOMENT + moments { + operations { + qubit_constant_index: 0 + phasedxpowgate { + phase_exponent { + float_value: 0.0 + } + exponent { + float_value: 0.5 + } + } + } + } + } +} +keyed_circuits { + key: "c2" + circuit { + scheduling_strategy: MOMENT_BY_MOMENT + moments { + operations { + qubit_constant_index: 0 + phasedxpowgate { + phase_exponent { + float_value: 0.0 + } + exponent { + float_value: 0.5 + } + } + } + } + } +} +constants { + qubit { + id: "5_2" + } +} +""", + v2.program_pb2.Program(), + ) +) + + @mock.patch('cirq_google.engine.engine_client.EngineClient.create_job_async') def test_run_sweeps_delegation(create_job_async): create_job_async.return_value = ('steve', quantum.QuantumJob()) @@ -315,29 +370,42 @@ def test_get_circuit_v2(get_program_async, include_empty_program: bool) -> None: ) get_program_async.assert_called_once_with('a', 'b', True) + # Test indexing on a batch program + program_batch = cg.EngineProgram('a', 'b', EngineContext()) + get_program_async.reset_mock() + get_program_async.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) with mock.patch( 'cirq_google.serialization.circuit_serializer.CIRCUIT_SERIALIZER.deserialize_multi_program' ) as deserialize_multi_program: deserialize_multi_program.return_value = [('key0', (), circuit), ('key1', (), circuit)] - assert program.get_circuit(program_num=1) is circuit + assert program_batch.get_circuit(circuit_num=1) is circuit deserialize_multi_program.assert_called_once() @duet.sync async def test_get_circuit_async(): context = EngineContext() + circuit = cirq.Circuit(cirq.X(cirq.GridQubit(5, 2)) ** 0.5) + + # Single circuit program = cg.EngineProgram('a', 'b', context) - circuit = cirq.Circuit() with mock.patch.object( context.client, 'get_program_async', new_callable=mock.AsyncMock ) as get_program_async: get_program_async.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) - with mock.patch( - 'cirq_google.engine.engine_program._deserialize_program' - ) as deserialize_program: - deserialize_program.return_value = circuit - assert await program.get_circuit_async(1) == circuit - deserialize_program.assert_called_once_with(mock.ANY, 1) + c = await program.get_circuit_async() + assert isinstance(c, cirq.Circuit) + + # Batch circuit + program_batch = cg.EngineProgram('a', 'b', context) + with mock.patch.object( + context.client, 'get_program_async', new_callable=mock.AsyncMock + ) as get_program_async: + get_program_async.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) + c = await program_batch.get_circuit_async(1) + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( + c, circuit + ) def test_deserialize_program(): @@ -347,7 +415,7 @@ def test_deserialize_program(): ) as mock_serializer: cg.engine.engine_program._deserialize_program(code) mock_serializer.deserialize.assert_called_once() - cg.engine.engine_program._deserialize_program(code, program_num=1) + cg.engine.engine_program._deserialize_program(code, circuit_num=1) mock_serializer.deserialize_multi_program.assert_called_once() @@ -390,3 +458,85 @@ def test_delete_jobs(delete_job_async): def test_str(): program = cg.EngineProgram('my-proj', 'my-prog', EngineContext()) assert str(program) == 'EngineProgram(project_id=\'my-proj\', program_id=\'my-prog\')' + + +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_batch_size(get_program_async): + # Single circuit program (not a batch) + program = cg.EngineProgram('a', 'b', EngineContext()) + get_program_async.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) + with pytest.raises(ValueError, match="not a batch program"): + _ = program.batch_size() + + # Batch program + program_batch = cg.EngineProgram('a', 'b', EngineContext()) + get_program_async.reset_mock() + get_program_async.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) + assert program_batch.batch_size() == 2 + get_program_async.assert_called_once_with('a', 'b', True) + + +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_get_circuits(get_program_async): + circuit = cirq.Circuit( + cirq.X(cirq.GridQubit(5, 2)) ** 0.5, cirq.measure(cirq.GridQubit(5, 2), key='result') + ) + + # Single circuit program + program = cg.EngineProgram('a', 'b', EngineContext()) + get_program_async.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) + circuits = program.get_circuits() + assert len(circuits) == 1 + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( + circuits[0], circuit + ) + + # Batch program + program_batch = cg.EngineProgram('a', 'b', EngineContext()) + get_program_async.reset_mock() + get_program_async.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) + circuits = program_batch.get_circuits() + assert len(circuits) == 2 + expected_circuit = cirq.Circuit(cirq.X(cirq.GridQubit(5, 2)) ** 0.5) + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( + circuits[0], expected_circuit + ) + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( + circuits[1], expected_circuit + ) + + +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_get_circuit_error_cases(get_program_async): + # Batch program, no index passed + program_batch = cg.EngineProgram('a', 'b', EngineContext()) + get_program_async.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) + with pytest.raises(ValueError, match="is a batch program containing 2 circuits"): + _ = program_batch.get_circuit() + + # Batch program, index out of range + with pytest.raises(IndexError, match="Index 2 out of range"): + _ = program_batch.get_circuit(2) + + # Single circuit program, indexing not allowed (except 0 and -1) + program_single = cg.EngineProgram('a', 'b', EngineContext()) + get_program_async.reset_mock() + get_program_async.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) + assert program_single.get_circuit(0) is not None + assert program_single.get_circuit(-1) is not None + with pytest.raises(IndexError, match="is not a batch program, cannot index 1"): + _ = program_single.get_circuit(1) + + +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_is_batch(get_program_async): + # Single circuit program + program = cg.EngineProgram('a', 'b', EngineContext()) + get_program_async.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) + assert not program.is_batch() + + # Batch program + program_batch = cg.EngineProgram('a', 'b', EngineContext()) + get_program_async.reset_mock() + get_program_async.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) + assert program_batch.is_batch() diff --git a/cirq-google/cirq_google/engine/simulated_local_job.py b/cirq-google/cirq_google/engine/simulated_local_job.py index 66a2bddb420..b483d1297d3 100644 --- a/cirq-google/cirq_google/engine/simulated_local_job.py +++ b/cirq-google/cirq_google/engine/simulated_local_job.py @@ -122,12 +122,12 @@ def _execute_results(self) -> Sequence[Sequence[EngineResult]]: Returns: a List of results from the sweep's execution. """ - reps, sweeps = self.get_repetitions_and_sweeps() + reps = self._repetitions + sweeps = self._sweeps parent = self.program() - batch_size = parent.batch_size() try: self._state = quantum.ExecutionStatus.State.RUNNING - programs = [parent.get_circuit(n) for n in range(batch_size)] + programs = parent.get_circuits() if len(sweeps) == 1 and len(programs) > 1: sweeps = sweeps * len(programs) batch_results = self._sampler.run_batch( From 6a73c66c61de9fc5f47a908ddf90741a3815514a Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Thu, 14 May 2026 09:54:27 -0700 Subject: [PATCH 2/4] Fix coverage and format. --- cirq-google/cirq_google/engine/abstract_job.py | 4 +++- .../cirq_google/engine/abstract_local_job.py | 4 +++- .../cirq_google/engine/abstract_local_job_test.py | 1 + .../cirq_google/engine/abstract_local_program.py | 6 +----- cirq-google/cirq_google/engine/engine_job.py | 4 +++- cirq-google/cirq_google/engine/engine_job_test.py | 7 +++---- cirq-google/cirq_google/engine/engine_program.py | 9 ++++----- .../cirq_google/engine/engine_program_test.py | 14 ++++++++------ 8 files changed, 26 insertions(+), 23 deletions(-) diff --git a/cirq-google/cirq_google/engine/abstract_job.py b/cirq-google/cirq_google/engine/abstract_job.py index 94ad91fd753..a641349f87c 100644 --- a/cirq-google/cirq_google/engine/abstract_job.py +++ b/cirq-google/cirq_google/engine/abstract_job.py @@ -141,7 +141,9 @@ def failure(self) -> tuple[str, str] | None: """Return failure code and message of the job if present.""" @abc.abstractmethod - def get_repetitions_and_sweeps(self, circuit_num: int | None = None) -> tuple[int, list[cirq.Sweep]]: + def get_repetitions_and_sweeps( + self, circuit_num: int | None = None + ) -> tuple[int, list[cirq.Sweep]]: """Returns the repetitions and sweeps for the job. Args: diff --git a/cirq-google/cirq_google/engine/abstract_local_job.py b/cirq-google/cirq_google/engine/abstract_local_job.py index 620bd8289bc..3e1fa9c0aec 100644 --- a/cirq-google/cirq_google/engine/abstract_local_job.py +++ b/cirq-google/cirq_google/engine/abstract_local_job.py @@ -157,7 +157,9 @@ def processor_ids(self) -> list[str]: """Returns the processor ids provided when the job was created.""" return [self._processor_id] - def get_repetitions_and_sweeps(self, circuit_num: int | None = None) -> tuple[int, list[cirq.Sweep]]: + def get_repetitions_and_sweeps( + self, circuit_num: int | None = None + ) -> tuple[int, list[cirq.Sweep]]: """Returns the repetitions and sweeps for the job. Args: diff --git a/cirq-google/cirq_google/engine/abstract_local_job_test.py b/cirq-google/cirq_google/engine/abstract_local_job_test.py index e77f483de88..cf3bd338106 100644 --- a/cirq-google/cirq_google/engine/abstract_local_job_test.py +++ b/cirq-google/cirq_google/engine/abstract_local_job_test.py @@ -22,6 +22,7 @@ from unittest import mock import pytest + import cirq from cirq_google.cloud import quantum from cirq_google.engine.abstract_local_job import AbstractLocalJob diff --git a/cirq-google/cirq_google/engine/abstract_local_program.py b/cirq-google/cirq_google/engine/abstract_local_program.py index 5fea814729f..7b91e0ea6a2 100644 --- a/cirq-google/cirq_google/engine/abstract_local_program.py +++ b/cirq-google/cirq_google/engine/abstract_local_program.py @@ -39,11 +39,7 @@ class AbstractLocalProgram(AbstractProgram): need to implement abstract methods. """ - def __init__( - self, - circuits: list[cirq.Circuit], - engine: AbstractLocalEngine, - ): + def __init__(self, circuits: list[cirq.Circuit], engine: AbstractLocalEngine): if not circuits: raise ValueError('No circuits provided to program.') self._create_time = datetime.datetime.now() diff --git a/cirq-google/cirq_google/engine/engine_job.py b/cirq-google/cirq_google/engine/engine_job.py index 2d532845487..a2111e512e2 100644 --- a/cirq-google/cirq_google/engine/engine_job.py +++ b/cirq-google/cirq_google/engine/engine_job.py @@ -226,7 +226,9 @@ def failure(self) -> tuple[str, str] | None: return (failure.error_code.name, failure.error_message) return None - def get_repetitions_and_sweeps(self, circuit_num: int | None = None) -> tuple[int, list[cirq.Sweep]]: + def get_repetitions_and_sweeps( + self, circuit_num: int | None = None + ) -> tuple[int, list[cirq.Sweep]]: """Returns the repetitions and sweeps for the Quantum Engine job. Args: diff --git a/cirq-google/cirq_google/engine/engine_job_test.py b/cirq-google/cirq_google/engine/engine_job_test.py index 0cc8cf594ef..ecaa5939775 100644 --- a/cirq-google/cirq_google/engine/engine_job_test.py +++ b/cirq-google/cirq_google/engine/engine_job_test.py @@ -30,7 +30,6 @@ from cirq_google.engine.engine import EngineContext from cirq_google.engine.stream_manager import StreamError - _PROGRAM_V2 = util.pack_any( Merge( """language { @@ -371,10 +370,10 @@ def test_get_repetitions_and_sweeps(get_job, get_program): sweep=v2.run_context_pb2.Sweep( single_sweep=v2.run_context_pb2.SingleSweep( parameter_key='t', - points=v2.run_context_pb2.Points(points_double=[1.0, 2.0]) + points=v2.run_context_pb2.Points(points_double=[1.0, 2.0]), ) - ) - ) + ), + ), ] ) ) diff --git a/cirq-google/cirq_google/engine/engine_program.py b/cirq-google/cirq_google/engine/engine_program.py index 27eac1afd01..7f6b42cb22e 100644 --- a/cirq-google/cirq_google/engine/engine_program.py +++ b/cirq-google/cirq_google/engine/engine_program.py @@ -62,6 +62,7 @@ def __init__( self.program_id = program_id self.context = context self._program = _program + self._proto: v2.program_pb2.Program | None = None async def run_sweep_async( self, @@ -353,7 +354,7 @@ async def remove_labels_async(self, keys: list[str]) -> EngineProgram: remove_labels = duet.sync(remove_labels_async) async def _get_proto_async(self) -> v2.program_pb2.Program: - if not hasattr(self, '_proto') or self._proto is None: + if self._proto is None: if self._program is None or not self._program.code or not self._program.code.type_url: self._program = await self.context.client.get_program_async( self.project_id, self.program_id, True @@ -419,7 +420,7 @@ async def get_circuits_async(self) -> list[cirq.Circuit]: return [serializer.deserialize(proto)] deserialized = serializer.deserialize_multi_program(proto) - return [triple[2] for triple in deserialized] + return [cast(cirq.Circuit, triple[2]) for triple in deserialized] get_circuits = duet.sync(get_circuits_async) @@ -478,7 +479,5 @@ def _deserialize_program(code: any_pb2.Any, circuit_num: int | None = None) -> c try: return cast(cirq.Circuit, serializer.deserialize_multi_program(program)[circuit_num][2]) except IndexError: - raise IndexError( - f"Index {circuit_num} out of range for batch program." - ) + raise IndexError(f"Index {circuit_num} out of range for batch program.") return serializer.deserialize(program) diff --git a/cirq-google/cirq_google/engine/engine_program_test.py b/cirq-google/cirq_google/engine/engine_program_test.py index 135b4d1e53d..09ebaa5b4e3 100644 --- a/cirq-google/cirq_google/engine/engine_program_test.py +++ b/cirq-google/cirq_google/engine/engine_program_test.py @@ -403,9 +403,7 @@ async def test_get_circuit_async(): ) as get_program_async: get_program_async.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) c = await program_batch.get_circuit_async(1) - cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( - c, circuit - ) + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(c, circuit) def test_deserialize_program(): @@ -419,6 +417,12 @@ def test_deserialize_program(): mock_serializer.deserialize_multi_program.assert_called_once() +def test_deserialize_program_errors(): + # Index out of range + with pytest.raises(IndexError, match="Index 2 out of range"): + cg.engine.engine_program._deserialize_program(_BATCH_PROGRAM_V2, circuit_num=2) + + @pytest.fixture(scope='module', autouse=True) def mock_grpc_client(): with mock.patch( @@ -487,9 +491,7 @@ def test_get_circuits(get_program_async): get_program_async.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) circuits = program.get_circuits() assert len(circuits) == 1 - cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( - circuits[0], circuit - ) + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(circuits[0], circuit) # Batch program program_batch = cg.EngineProgram('a', 'b', EngineContext()) From bdd7756028df2f4f84f1ad27e21c1852db7a82ff Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Thu, 14 May 2026 11:03:02 -0700 Subject: [PATCH 3/4] line size --- cirq-google/cirq_google/engine/abstract_local_job.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-google/cirq_google/engine/abstract_local_job.py b/cirq-google/cirq_google/engine/abstract_local_job.py index 3e1fa9c0aec..19ffca7e9f9 100644 --- a/cirq-google/cirq_google/engine/abstract_local_job.py +++ b/cirq-google/cirq_google/engine/abstract_local_job.py @@ -197,7 +197,7 @@ def get_repetitions_and_sweeps( return (self._repetitions, [self._sweeps[circuit_num]]) except IndexError: raise IndexError( - f"Index {circuit_num} out of range for batch job sweeps of size {len(self._sweeps)}." + f"Index {circuit_num} out of range for sweeps of size {len(self._sweeps)}." ) def get_processor(self) -> AbstractLocalProcessor: From 21888103bec7cd161480ef75b7c91a79d5a5f294 Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Wed, 20 May 2026 13:26:34 -0700 Subject: [PATCH 4/4] Review comments --- .../cirq_google/engine/abstract_local_job.py | 45 ++++++------------- .../cirq_google/engine/abstract_program.py | 2 +- cirq-google/cirq_google/engine/engine_job.py | 35 ++++++--------- .../cirq_google/engine/engine_job_test.py | 3 -- .../cirq_google/engine/engine_program.py | 2 +- .../cirq_google/engine/engine_program_test.py | 1 + 6 files changed, 31 insertions(+), 57 deletions(-) diff --git a/cirq-google/cirq_google/engine/abstract_local_job.py b/cirq-google/cirq_google/engine/abstract_local_job.py index 19ffca7e9f9..d58c915c0cf 100644 --- a/cirq-google/cirq_google/engine/abstract_local_job.py +++ b/cirq-google/cirq_google/engine/abstract_local_job.py @@ -160,45 +160,28 @@ def processor_ids(self) -> list[str]: def get_repetitions_and_sweeps( self, circuit_num: int | None = None ) -> tuple[int, list[cirq.Sweep]]: - """Returns the repetitions and sweeps for the job. - - Args: - circuit_num: if this is a batch job, the index of the circuit - to return the sweeps for. This argument is zero-indexed. - Negative values index from the end of the list. - - Returns: - A tuple of the repetition count and list of sweeps. - """ is_batch = self.program().is_batch() batch_size = self.program().batch_size() if is_batch else 1 is_mapped = is_batch and len(self._sweeps) == batch_size and len(self._sweeps) > 1 - - if circuit_num is None: - if is_mapped: + if is_mapped: + if circuit_num is None: raise ValueError( f"This is a batch job with {len(self._sweeps)} mapped sweeps. " "Please specify `circuit_num` to get sweeps for a specific circuit." ) - return (self._repetitions, self._sweeps) - - if not is_batch: - if circuit_num != 0 and circuit_num != -1: - raise IndexError(f"Job is not a batch job, cannot index {circuit_num}") - return (self._repetitions, self._sweeps) - - if not is_mapped: - # Shared sweeps in a batch job, return all of them - return (self._repetitions, self._sweeps) - - # Mapped sweeps in a batch job - try: - return (self._repetitions, [self._sweeps[circuit_num]]) - except IndexError: - raise IndexError( - f"Index {circuit_num} out of range for sweeps of size {len(self._sweeps)}." - ) + # Mapped sweeps in a batch job + try: + return (self._repetitions, [self._sweeps[circuit_num]]) + except IndexError: + raise IndexError( + f"Index {circuit_num} out of range for sweeps of size {len(self._sweeps)}." + ) + + # Not a batch job + if not is_batch and circuit_num and circuit_num != -1: + raise IndexError(f"Job is not a batch job, cannot index {circuit_num}") + return (self._repetitions, self._sweeps) def get_processor(self) -> AbstractLocalProcessor: """Returns the AbstractProcessor for the processor the job is/was run on, diff --git a/cirq-google/cirq_google/engine/abstract_program.py b/cirq-google/cirq_google/engine/abstract_program.py index e0ed7519e2c..488f10f4dad 100644 --- a/cirq-google/cirq_google/engine/abstract_program.py +++ b/cirq-google/cirq_google/engine/abstract_program.py @@ -172,7 +172,7 @@ def get_circuit(self, circuit_num: int | None = None) -> cirq.Circuit: """ @abc.abstractmethod - def get_circuits(self) -> list[cirq.Circuit]: + def get_circuits(self) -> Sequence[cirq.Circuit]: """Returns all the cirq Circuits for the program. Returns: diff --git a/cirq-google/cirq_google/engine/engine_job.py b/cirq-google/cirq_google/engine/engine_job.py index a2111e512e2..aaa3f55ddd4 100644 --- a/cirq-google/cirq_google/engine/engine_job.py +++ b/cirq-google/cirq_google/engine/engine_job.py @@ -247,31 +247,24 @@ def get_repetitions_and_sweeps( batch_size = self.program().batch_size() if is_batch else 1 is_mapped = is_batch and len(sweeps) == batch_size and len(sweeps) > 1 - - if circuit_num is None: - if is_mapped: + if is_mapped: + if circuit_num is None: raise ValueError( f"This is a batch job with {len(sweeps)} mapped sweeps. " "Please specify `circuit_num` to get sweeps for a specific circuit." ) - return (reps, sweeps) - - if not is_batch: - if circuit_num != 0 and circuit_num != -1: - raise IndexError(f"Job is not a batch job, cannot index {circuit_num}") - return (reps, sweeps) - - if not is_mapped: - # Shared sweeps in a batch job, return all of them - return (reps, sweeps) - - # Mapped sweeps in a batch job - try: - return (reps, [sweeps[circuit_num]]) - except IndexError: - raise IndexError( - f"Index {circuit_num} out of range for batch job sweeps of size {len(sweeps)}." - ) + # Mapped sweeps in a batch job + try: + return (reps, [sweeps[circuit_num]]) + except IndexError: + raise IndexError( + f"Index {circuit_num} out of range for sweeps of size {len(sweeps)}." + ) + + # Not a batch job + if not is_batch and circuit_num and circuit_num != -1: + raise IndexError(f"Job is not a batch job, cannot index {circuit_num}") + return (reps, sweeps) def get_processor(self) -> engine_processor.EngineProcessor | None: """Returns the EngineProcessor for the processor the job is/was run on, diff --git a/cirq-google/cirq_google/engine/engine_job_test.py b/cirq-google/cirq_google/engine/engine_job_test.py index ecaa5939775..bdf89c282af 100644 --- a/cirq-google/cirq_google/engine/engine_job_test.py +++ b/cirq-google/cirq_google/engine/engine_job_test.py @@ -343,9 +343,7 @@ def test_get_repetitions_and_sweeps(get_job, get_program): _ = job.get_repetitions_and_sweeps(1) # Batch program, shared sweep - get_program.reset_mock() get_program.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) - get_job.reset_mock() get_job.return_value = quantum.QuantumJob( run_context=util.pack_any( v2.run_context_pb2.RunContext( @@ -359,7 +357,6 @@ def test_get_repetitions_and_sweeps(get_job, get_program): assert job_batch_shared.get_repetitions_and_sweeps(1) == (10, [cirq.UnitSweep]) # Batch program, mapped sweeps - get_job.reset_mock() get_job.return_value = quantum.QuantumJob( run_context=util.pack_any( v2.run_context_pb2.RunContext( diff --git a/cirq-google/cirq_google/engine/engine_program.py b/cirq-google/cirq_google/engine/engine_program.py index 7f6b42cb22e..66936b4bc31 100644 --- a/cirq-google/cirq_google/engine/engine_program.py +++ b/cirq-google/cirq_google/engine/engine_program.py @@ -365,7 +365,7 @@ async def _get_proto_async(self) -> v2.program_pb2.Program: async def is_batch_async(self) -> bool: """Returns True if the program is a batch program.""" proto = await self._get_proto_async() - return proto.WhichOneof('program') != 'circuit' + return len(proto.keyed_circuits) > 0 is_batch = duet.sync(is_batch_async) diff --git a/cirq-google/cirq_google/engine/engine_program_test.py b/cirq-google/cirq_google/engine/engine_program_test.py index 09ebaa5b4e3..51b0f028751 100644 --- a/cirq-google/cirq_google/engine/engine_program_test.py +++ b/cirq-google/cirq_google/engine/engine_program_test.py @@ -524,6 +524,7 @@ def test_get_circuit_error_cases(get_program_async): program_single = cg.EngineProgram('a', 'b', EngineContext()) get_program_async.reset_mock() get_program_async.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) + assert program_single.get_circuit() is not None assert program_single.get_circuit(0) is not None assert program_single.get_circuit(-1) is not None with pytest.raises(IndexError, match="is not a batch program, cannot index 1"):