From f00d1975bb032cc6f93341823b876cd54309c491 Mon Sep 17 00:00:00 2001 From: mhucka <1450019+mhucka@users.noreply.github.com> Date: Fri, 27 Mar 2026 05:27:25 +0000 Subject: [PATCH 1/3] =?UTF-8?q?=F0=9F=A7=AA=20[testing=20improvement]=20Ad?= =?UTF-8?q?d=20tests=20for=20random=5Fsymbol=5Fcircuit=5Fresolver=5Fbatch?= =?UTF-8?q?=20and=20random=5Fcircuit=5Fresolver=5Fbatch?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added comprehensive unit tests to `tensorflow_quantum/python/util_test.py` to verify the output shapes and types of `random_symbol_circuit_resolver_batch` and `random_circuit_resolver_batch`. These tests ensure that the batch generators return the correct number of `cirq.Circuit` and `cirq.ParamResolver` objects, and that symbols are correctly present in the resolvers when applicable. --- tensorflow_quantum/python/util_test.py | 31 ++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tensorflow_quantum/python/util_test.py b/tensorflow_quantum/python/util_test.py index 00715899b..3edfa92b1 100644 --- a/tensorflow_quantum/python/util_test.py +++ b/tensorflow_quantum/python/util_test.py @@ -428,6 +428,37 @@ def test_get_circuit_symbols_error(self): 'cirq.Circuit'): util.get_circuit_symbols(param) + def test_random_circuit_resolver_batch(self): + """Confirm that random_circuit_resolver_batch works.""" + qubits = cirq.GridQubit.rect(1, 2) + batch_size = 5 + circuits, resolvers = util.random_circuit_resolver_batch( + qubits, batch_size) + self.assertEqual(len(circuits), batch_size) + self.assertEqual(len(resolvers), batch_size) + for circuit in circuits: + self.assertIsInstance(circuit, cirq.Circuit) + for resolver in resolvers: + self.assertIsInstance(resolver, cirq.ParamResolver) + self.assertEqual(len(resolver.param_dict), 0) + + def test_random_symbol_circuit_resolver_batch(self): + """Confirm that random_symbol_circuit_resolver_batch works.""" + qubits = cirq.GridQubit.rect(1, 2) + symbols = [sympy.Symbol('a'), sympy.Symbol('b')] + batch_size = 5 + circuits, resolvers = util.random_symbol_circuit_resolver_batch( + qubits, symbols, batch_size) + self.assertEqual(len(circuits), batch_size) + self.assertEqual(len(resolvers), batch_size) + for circuit in circuits: + self.assertIsInstance(circuit, cirq.Circuit) + for resolver in resolvers: + self.assertIsInstance(resolver, cirq.ParamResolver) + self.assertEqual(len(resolver.param_dict), len(symbols)) + for symbol in symbols: + self.assertIn(symbol, resolver.param_dict) + class ExponentialUtilFunctionsTest(tf.test.TestCase): """Test that Exponential utility functions work.""" From d341ee0667503663b4dfccd0fc90feebed0403e7 Mon Sep 17 00:00:00 2001 From: Michael Hucka Date: Fri, 3 Apr 2026 22:29:01 -0700 Subject: [PATCH 2/3] Update tensorflow_quantum/python/util_test.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tensorflow_quantum/python/util_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow_quantum/python/util_test.py b/tensorflow_quantum/python/util_test.py index 3edfa92b1..3253bbf61 100644 --- a/tensorflow_quantum/python/util_test.py +++ b/tensorflow_quantum/python/util_test.py @@ -438,6 +438,8 @@ def test_random_circuit_resolver_batch(self): self.assertEqual(len(resolvers), batch_size) for circuit in circuits: self.assertIsInstance(circuit, cirq.Circuit) + self.assertGreater(len(circuit), 0, "Generated circuit should not be empty.") + self.assertEqual(len(util.get_circuit_symbols(circuit)), 0, "Circuit should not have symbols.") for resolver in resolvers: self.assertIsInstance(resolver, cirq.ParamResolver) self.assertEqual(len(resolver.param_dict), 0) From 6c8c4f13d559b9e9add44568e2041889a40a599f Mon Sep 17 00:00:00 2001 From: Michael Hucka Date: Fri, 3 Apr 2026 22:29:19 -0700 Subject: [PATCH 3/3] Update tensorflow_quantum/python/util_test.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tensorflow_quantum/python/util_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow_quantum/python/util_test.py b/tensorflow_quantum/python/util_test.py index 3253bbf61..365bb18ef 100644 --- a/tensorflow_quantum/python/util_test.py +++ b/tensorflow_quantum/python/util_test.py @@ -455,6 +455,9 @@ def test_random_symbol_circuit_resolver_batch(self): self.assertEqual(len(resolvers), batch_size) for circuit in circuits: self.assertIsInstance(circuit, cirq.Circuit) + extracted_symbols = util.get_circuit_symbols(circuit) + expected_symbols = sorted([str(s) for s in symbols]) + self.assertListEqual(expected_symbols, sorted(extracted_symbols)) for resolver in resolvers: self.assertIsInstance(resolver, cirq.ParamResolver) self.assertEqual(len(resolver.param_dict), len(symbols))