Skip to content

Commit 9486e3f

Browse files
Add unit tests for TFQPauliSumCollector.collect (#1015)
This change implements unit tests for the `collect` method in `TFQPauliSumCollector`, addressing a testing gap in `batch_util_test.py`. The new tests cover: - Standard Pauli Z observable (expected energy -1.0) - Identity observable (expected energy 1.0) - Mixed observable (Z + 2.0*I) (expected energy 1.0) --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent c987980 commit 9486e3f

1 file changed

Lines changed: 22 additions & 0 deletions

File tree

tensorflow_quantum/core/ops/batch_util_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,28 @@ def test_no_circuit(self, sim):
308308
self.assertDTypeEqual(results, np.int8)
309309
self.assertEqual(np.zeros(shape=(0, 0, 0)).shape, results.shape)
310310

311+
def test_pauli_sum_collector_collect(self):
312+
"""Test the collect method of TFQPauliSumCollector."""
313+
qubit = cirq.GridQubit(0, 0)
314+
circuit = cirq.Circuit(cirq.X(qubit))
315+
samples_per_term = 100
316+
sampler = cirq.Simulator()
317+
318+
test_cases = [
319+
("Pauli observable (Z)", cirq.PauliSum.wrap(cirq.Z(qubit)), -1.0),
320+
("Identity observable", cirq.PauliSum.wrap(cirq.I(qubit)), 1.0),
321+
("Mixed observable (Z + 2.0*I)",
322+
cirq.Z(qubit) + 2.0 * cirq.I(qubit), 1.0),
323+
]
324+
325+
for name, observable, expected_energy in test_cases:
326+
with self.subTest(name):
327+
collector = batch_util.TFQPauliSumCollector(
328+
circuit, observable, samples_per_term=samples_per_term)
329+
collector.collect(sampler)
330+
self.assertAlmostEqual(collector.estimated_energy(),
331+
expected_energy)
332+
311333

312334
if __name__ == '__main__':
313335
tf.test.main()

0 commit comments

Comments
 (0)