Skip to content
33 changes: 33 additions & 0 deletions tensorflow_quantum/core/ops/batch_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,39 @@ def test_no_circuit(self, sim):
self.assertDTypeEqual(results, np.int8)
self.assertEqual(np.zeros(shape=(0, 0, 0)).shape, results.shape)

def test_pauli_sum_collector_collect(self):
"""Test the collect method of TFQPauliSumCollector."""
qubit = cirq.GridQubit(0, 0)
circuit = cirq.Circuit(cirq.X(qubit))
samples_per_term = 100
sampler = cirq.Simulator()

# Case 1: Standard Pauli observable (Z). Expect -1.0.
observable1 = cirq.PauliSum.wrap(cirq.Z(qubit))
collector1 = batch_util.TFQPauliSumCollector(
circuit, observable1, samples_per_term=samples_per_term)
collector1.collect(sampler)

pauli_string1 = list(observable1)[0]
pauli_string1 = pauli_string1 / pauli_string1.coefficient
self.assertEqual(collector1._zeros[pauli_string1], 0)
self.assertEqual(collector1._ones[pauli_string1], samples_per_term)
self.assertAlmostEqual(collector1.estimated_energy(), -1.0)

# Case 2: Identity observable. Expect 1.0.
observable2 = cirq.PauliSum.wrap(cirq.I(qubit))
collector2 = batch_util.TFQPauliSumCollector(
circuit, observable2, samples_per_term=samples_per_term)
collector2.collect(sampler)
self.assertAlmostEqual(collector2.estimated_energy(), 1.0)

# Case 3: Mixed observable (Z + 2.0*I). Expect -1.0 + 2.0 = 1.0.
observable3 = cirq.Z(qubit) + 2.0 * cirq.I(qubit)
collector3 = batch_util.TFQPauliSumCollector(
circuit, observable3, samples_per_term=samples_per_term)
collector3.collect(sampler)
self.assertAlmostEqual(collector3.estimated_energy(), 1.0)
Comment thread
mhucka marked this conversation as resolved.
Outdated


if __name__ == '__main__':
tf.test.main()
Loading