diff --git a/MaxKernel/evaluation/.gitignore b/MaxKernel/evaluation/.gitignore index 91d69b9..4c4bcf8 100644 --- a/MaxKernel/evaluation/.gitignore +++ b/MaxKernel/evaluation/.gitignore @@ -1,4 +1,4 @@ .pytest_cache/ .env -*functional_test.py +*_test.py test_data/* diff --git a/MaxKernel/evaluation/code_adapter/code_adapter.py b/MaxKernel/evaluation/code_adapter/code_adapter.py index 979abef..08951b0 100644 --- a/MaxKernel/evaluation/code_adapter/code_adapter.py +++ b/MaxKernel/evaluation/code_adapter/code_adapter.py @@ -8,8 +8,7 @@ adapt_reference_prompt, ) from evaluation.custom_types.kernel_task import KernelTask - -GEMINI_MODEL = "gemini-2.5-flash" +from hitl_agent.constants import MODEL_NAME logging.basicConfig( level=logging.INFO, @@ -60,7 +59,7 @@ def adapt( while attempt < self.max_retries: try: response = self.client.models.generate_content( - model=GEMINI_MODEL, contents=prompt, config=config + model=MODEL_NAME, contents=prompt, config=config ) code = response.text.strip() if code.startswith("```python"): diff --git a/MaxKernel/evaluation/code_adapter/code_adapter_test.py b/MaxKernel/evaluation/code_adapter/code_adapter_test.py deleted file mode 100644 index fe3b523..0000000 --- a/MaxKernel/evaluation/code_adapter/code_adapter_test.py +++ /dev/null @@ -1,151 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch - -from google import genai - -from evaluation.code_adapter import code_adapter -from evaluation.custom_types.kernel_task import KernelTask - - -class TestCodeAdapter(unittest.TestCase): - def setUp(self): - self.mock_client = MagicMock(spec=genai.Client) - # Use 2 retries to make the test run faster - self.adapter = code_adapter.CodeAdapter( - client=self.mock_client, max_retries=2 - ) - - @patch.object(code_adapter.CodeAdapter, "_get_adapt_reference_prompt") - def test_adapt_reference_success(self, mock_prompt): - mock_prompt.return_value = "mock reference prompt" - mock_response = MagicMock() - # Test that it properly strips python markdown formatting - mock_response.text = "```python\n# Imports\nimport jax\n# Initialization\nx = 1\n# Computation\ndef comp(): pass\n```" - self.mock_client.models.generate_content.return_value = mock_response - - result = self.adapter.adapt("original code") - - self.assertEqual( - result, - "# Imports\nimport jax\n# Initialization\nx = 1\n# Computation\ndef comp(): pass", - ) - self.mock_client.models.generate_content.assert_called_once() - - @patch.object(code_adapter.CodeAdapter, "_get_adapt_optimized_prompt") - def test_adapt_optimized_success(self, mock_prompt): - mock_prompt.return_value = "mock optimized prompt" - mock_response = MagicMock() - # Test with no markdown backticks - mock_response.text = "# Imports\nimport jax\n# Initialization\nx = 1\n# Computation\ndef comp(): pass" - self.mock_client.models.generate_content.return_value = mock_response - - result = self.adapter.adapt( - "original code", - adapt_optimized=True, - get_inputs_code="def get_inputs(): pass", - ) - - self.assertEqual( - result, - "# Imports\nimport jax\n# Initialization\nx = 1\n# Computation\ndef comp(): pass", - ) - self.mock_client.models.generate_content.assert_called_once() - - def test_adapt_optimized_missing_get_inputs(self): - with self.assertRaisesRegex(ValueError, "get_inputs_code must be provided"): - self.adapter.adapt("original code", adapt_optimized=True) - - @patch.object(code_adapter.time, "sleep") - @patch.object(code_adapter.CodeAdapter, "_get_adapt_reference_prompt") - def test_adapt_retries_and_fails_on_missing_sections( - self, mock_prompt, mock_sleep - ): - mock_prompt.return_value = "mock prompt" - mock_response = MagicMock() - # Missing the required # Imports, # Initialization, # Computation sections - mock_response.text = "```python\ndef bad_format(): pass\n```" - self.mock_client.models.generate_content.return_value = mock_response - - with self.assertRaisesRegex(RuntimeError, "Failed to refactor code"): - self.adapter.adapt("original code") - - # max_retries = 2, so it should attempt 2 times - self.assertEqual(self.mock_client.models.generate_content.call_count, 2) - mock_sleep.assert_called() - - @patch.object(code_adapter.time, "sleep") - @patch.object(code_adapter.CodeAdapter, "_get_adapt_reference_prompt") - def test_adapt_retries_and_fails_on_exception(self, mock_prompt, mock_sleep): - mock_prompt.return_value = "mock prompt" - self.mock_client.models.generate_content.side_effect = Exception( - "API Error" - ) - - with self.assertRaisesRegex(RuntimeError, "Failed to refactor code"): - self.adapter.adapt("original code") - - self.assertEqual(self.mock_client.models.generate_content.call_count, 2) - mock_sleep.assert_called() - - def test_extract_input_gen_code_success(self): - sample_code = ( - "# Imports\n" - "import jax\n" - "import jax.numpy as jnp\n" - "# Initialization\n" - "BATCH = 8\n" - "\n" - "def get_inputs():\n" - " x = jnp.zeros(BATCH)\n" - " return [x], []\n" - "# Computation\n" - "def computation(x):\n" - " return x\n" - ) - expected_extracted = ( - "def get_inputs():\n" - " import jax\n" - " import jax.numpy as jnp\n" - "\n" - " BATCH = 8\n" - "\n" - " x = jnp.zeros(BATCH)\n" - " return [x], []" - ) - - result = self.adapter._extract_input_gen_code(sample_code) - self.assertEqual(result, expected_extracted) - - def test_extract_input_gen_code_missing_get_inputs(self): - sample_code = ( - "# Imports\n" - "import jax\n" - "# Initialization\n" - "BATCH = 8\n" - "# Computation\n" - "def computation(x):\n" - " return x\n" - ) - result = self.adapter._extract_input_gen_code(sample_code) - self.assertEqual(result, "") - - def test_generate_kernel_task(self): - sample_code = ( - "# Imports\n" - "import jax\n" - "# Initialization\n" - "def get_inputs():\n" - " return [], []\n" - "# Computation\n" - "def computation(): pass\n" - ) - - task = self.adapter.generate_kernel_task( - "test_id", "test desc", sample_code - ) - - self.assertIsInstance(task, KernelTask) - self.assertEqual(task.task_id, "test_id") - self.assertEqual(task.description, "test desc") - self.assertIn("def get_inputs():", task.input_gen_code) - self.assertIn("import jax", task.input_gen_code) diff --git a/MaxKernel/evaluation/evaluation_utils_test.py b/MaxKernel/evaluation/evaluation_utils_test.py deleted file mode 100644 index 1d9280e..0000000 --- a/MaxKernel/evaluation/evaluation_utils_test.py +++ /dev/null @@ -1,189 +0,0 @@ -import io -import os -import unittest -from contextlib import redirect_stdout -from unittest.mock import mock_open, patch - -from evaluation.custom_types.evaluation_result import EvaluationResult -from evaluation.custom_types.kernel_task import KernelTask -from evaluation.evaluation_utils import ( - load_kernel_task_from_yaml, - print_eval_result, - summarize_results, - write_kernel_task_to_yaml, -) - - -class TestEvaluationUtils(unittest.TestCase): - @patch("builtins.open", new_callable=mock_open) - def test_write_kernel_task_to_yaml(self, mock_file): - task = KernelTask( - task_id="test_task_1", - description="A test description.", - input_gen_code="def get_inputs():\n return [], []\n", - ) - - write_kernel_task_to_yaml(task, "dummy_path.yaml") - - mock_file.assert_called_once_with("dummy_path.yaml", "w", encoding="utf-8") - handle = mock_file() - # Assert that write was called; PyYAML makes multiple write calls, - # so we can combine them to check the final output string - written_content = "".join( - call.args[0] for call in handle.write.call_args_list - ) - self.assertIn("test_task_1", written_content) - - @patch("evaluation.evaluation_utils.os.path.exists", return_value=True) - def test_load_kernel_task_from_yaml(self, mock_exists): - mock_yaml_content = "task_id: test_task_1\ndescription: A test description.\ninput_gen_code: |\n def get_inputs():\n return [], []\n" - with patch("builtins.open", mock_open(read_data=mock_yaml_content)): - loaded_task = load_kernel_task_from_yaml("dummy_path.yaml") - - self.assertEqual(loaded_task.task_id, "test_task_1") - self.assertIn("def get_inputs():", loaded_task.input_gen_code) - - def test_load_missing_yaml(self): - with self.assertRaises(FileNotFoundError): - load_kernel_task_from_yaml("non_existent_file.yaml") - - def test_print_eval_result_success(self): - result = EvaluationResult( - task_id="test_1", - compiled_successfully=True, - numerically_correct=True, - max_abs_diff=1.0e-3, - max_rel_diff=1.0e-3, - reference_time_ms=10.0, - optimized_time_ms=5.0, - error_trace=None, - ) - - f = io.StringIO() - with redirect_stdout(f): - print_eval_result(result) - - output = f.getvalue() - self.assertIn("Correctness: [PASS]", output) - self.assertIn("Max Absolute Difference: 1.000000e-03", output) - self.assertIn("Max Relative Difference: 1.000000e-03", output) - self.assertIn("Reference time: 10.000 ms", output) - self.assertIn("Optimized time: 5.000 ms", output) - self.assertIn("Speedup: 2.00x", output) - - def test_print_eval_result_error(self): - result = EvaluationResult( - task_id="test_2", error_trace="Compilation failed due to syntax error." - ) - - f = io.StringIO() - with redirect_stdout(f): - print_eval_result(result) - - output = f.getvalue() - self.assertIn( - "Error: Compilation failed due to syntax error.", output - ) - - @patch("evaluation.evaluation_utils.logger.info") - def test_summarize_results(self, mock_logger_info): - results = [ - {"task_id": "1", "compiled_successfully": False}, - { - "task_id": "2", - "compiled_successfully": True, - "numerically_correct": False, - }, - { - "task_id": "3", - "compiled_successfully": True, - "numerically_correct": True, - "speedup": 2.0, - }, - { - "task_id": "4", - "compiled_successfully": True, - "numerically_correct": True, - "speedup": 0.5, - }, - { - "task_id": "5", - "compiled_successfully": True, - "numerically_correct": True, - "speedup": 8.0, - }, - ] - - summarize_results(results) - - log_messages = [call.args[0] for call in mock_logger_info.call_args_list] - log_text = "\n".join(log_messages) - - self.assertIn("- Attempted / Evaluated: 5", log_text) - self.assertIn("- Compiled Successfully: 4 (80.00%)", log_text) - self.assertIn("- Failed Compilation/Runtime: 1", log_text) - self.assertIn("- Numerically Correct: 3 (60.00%)", log_text) - self.assertIn("- Improved (Speedup > 1x): 2 (40.00%)", log_text) - self.assertIn("- Slower (Speedup < 1x): 1 (20.00%)", log_text) - self.assertIn("- Max Speedup: 8.00x", log_text) - self.assertIn("- Average Speedup/Slowdown (Arithmetic): 3.50x", log_text) - self.assertIn("- Average Speedup/Slowdown (Geometric): 2.00x", log_text) - - @patch("evaluation.evaluation_utils.logger.info") - def test_summarize_results_empty(self, mock_logger_info): - summarize_results([]) - mock_logger_info.assert_called_once_with( - "No results found to summarize and no tasks discovered." - ) - - @patch("evaluation.evaluation_utils.os.makedirs") - @patch("builtins.open", new_callable=mock_open) - @patch("evaluation.evaluation_utils.logger.info") - def test_summarize_results_with_output_dir( - self, mock_logger_info, mock_file, mock_makedirs - ): - results = [ - {"task_id": "1", "compiled_successfully": False}, - { - "task_id": "2", - "compiled_successfully": True, - "numerically_correct": False, - }, - { - "task_id": "3", - "compiled_successfully": True, - "numerically_correct": True, - "speedup": 2.0, - }, - { - "task_id": "4", - "compiled_successfully": True, - "numerically_correct": True, - "speedup": 0.5, - }, - { - "task_id": "5", - "compiled_successfully": True, - "numerically_correct": True, - "speedup": 8.0, - }, - ] - - summarize_results(results, output_dir="/fake/dir") - - mock_makedirs.assert_called_once_with("/fake/dir", exist_ok=True) - mock_file.assert_any_call( - os.path.join("/fake/dir", "summary.txt"), "w", encoding="utf-8" - ) - mock_file.assert_any_call( - os.path.join("/fake/dir", "summary.json"), "w", encoding="utf-8" - ) - - handle = mock_file() - written_content = "".join( - call.args[0] for call in handle.write.call_args_list - ) - self.assertIn('"total_attempted": 5', written_content) - self.assertIn('"compiled_successfully": 4', written_content) - self.assertIn('"numerically_correct": 3', written_content) - self.assertIn('"improved": 2', written_content) diff --git a/MaxKernel/evaluation/jax_kernel_evaluator_test.py b/MaxKernel/evaluation/jax_kernel_evaluator_test.py deleted file mode 100644 index 870c458..0000000 --- a/MaxKernel/evaluation/jax_kernel_evaluator_test.py +++ /dev/null @@ -1,290 +0,0 @@ -import json -import subprocess -import unittest -from unittest.mock import MagicMock, mock_open, patch - -from evaluation.custom_types.kernel_task import KernelTask -from evaluation.jax_kernel_evaluator import JAXKernelEvaluator - - -class TestJAXKernelEvaluator(unittest.TestCase): - @patch("evaluation.jax_kernel_evaluator.TPUVMClient") - def test_init_remote_success(self, mock_client): - evaluator = JAXKernelEvaluator( - local=False, - tpu_name="test-tpu", - project="test-proj", - zone="us-central1-a", - ) - self.assertFalse(evaluator.local) - mock_client.assert_called_once_with( - project="test-proj", zone="us-central1-a", tpu_name="test-tpu" - ) - - def test_init_remote_missing_args(self): - with self.assertRaises(ValueError): - JAXKernelEvaluator(local=False) - - def test_init_local(self): - evaluator = JAXKernelEvaluator(local=True) - self.assertTrue(evaluator.local) - self.assertIsNone(evaluator.client) - - @patch.object(JAXKernelEvaluator, "_evaluate_local") - def test_evaluate_routing_local(self, mock_eval_local): - evaluator = JAXKernelEvaluator(local=True) - evaluator.evaluate("ref.py", "opt.py", "task.yaml") - mock_eval_local.assert_called_once_with( - "ref.py", "opt.py", "task.yaml", None, 300, True, 1e-3, 1e-3 - ) - - @patch.object(JAXKernelEvaluator, "_evaluate_remote") - @patch("evaluation.jax_kernel_evaluator.TPUVMClient") - def test_evaluate_routing_remote(self, mock_client, mock_eval_remote): - evaluator = JAXKernelEvaluator( - local=False, tpu_name="t", project="p", zone="z" - ) - evaluator.evaluate("ref.py", "opt.py", "task.yaml") - mock_eval_remote.assert_called_once_with( - "ref.py", "opt.py", "task.yaml", None, 300, True, 1e-3, 1e-3 - ) - - @patch("evaluation.jax_kernel_evaluator.load_kernel_task_from_yaml") - @patch("evaluation.jax_kernel_evaluator.shutil") - @patch("evaluation.jax_kernel_evaluator.tempfile") - @patch("evaluation.jax_kernel_evaluator.subprocess.run") - @patch("os.path.exists") - @patch("evaluation.jax_kernel_evaluator.print_eval_result") - def test_evaluate_local_success( - self, - mock_print, - mock_exists, - mock_subproc, - mock_tempfile, - mock_shutil, - mock_load_task, - ): - evaluator = JAXKernelEvaluator(local=True) - mock_load_task.return_value = KernelTask( - task_id="test_task", description="", input_gen_code="" - ) - mock_tempfile.mkdtemp.return_value = "/fake/dir" - mock_exists.return_value = True # for os.path.exists(result_local) - - result_data = { - "compiled_successfully": True, - "numerically_correct": True, - "reference_time_ms": 10.0, - "optimized_time_ms": 5.0, - "xprof_reference_time_ms": 8.0, - "xprof_optimized_time_ms": 4.0, - } - - with patch("builtins.open", mock_open(read_data=json.dumps(result_data))): - with ( - patch.object(evaluator, "_build_task_json"), - patch.object(evaluator, "_build_harness_code"), - ): - result = evaluator._evaluate_local("ref.py", "opt.py", "task.yaml") - - self.assertTrue(result.compiled_successfully) - self.assertTrue(result.numerically_correct) - self.assertEqual(result.speedup, 2.0) - self.assertEqual(result.xprof_reference_time_ms, 8.0) - self.assertEqual(result.xprof_optimized_time_ms, 4.0) - mock_subproc.assert_called_once() - mock_shutil.rmtree.assert_called_once_with("/fake/dir", ignore_errors=True) - - @patch("evaluation.jax_kernel_evaluator.load_kernel_task_from_yaml") - @patch("evaluation.jax_kernel_evaluator.shutil") - @patch("evaluation.jax_kernel_evaluator.tempfile") - @patch("evaluation.jax_kernel_evaluator.subprocess.run") - def test_evaluate_local_timeout( - self, mock_subproc, mock_tempfile, mock_shutil, mock_load_task - ): - evaluator = JAXKernelEvaluator(local=True) - mock_load_task.return_value = KernelTask( - task_id="test_task", description="", input_gen_code="" - ) - mock_tempfile.mkdtemp.return_value = "/fake/dir" - mock_subproc.side_effect = subprocess.TimeoutExpired( - cmd="python3 harness.py", timeout=300 - ) - - with ( - patch("builtins.open", mock_open()), - patch.object(evaluator, "_build_task_json"), - patch.object(evaluator, "_build_harness_code"), - ): - result = evaluator._evaluate_local( - "ref.py", "opt.py", "task.yaml", timeout_seconds=300 - ) - - self.assertIn("timed out", result.error_trace) - - @patch("evaluation.jax_kernel_evaluator.load_kernel_task_from_yaml") - @patch("evaluation.jax_kernel_evaluator.shutil") - @patch("evaluation.jax_kernel_evaluator.tempfile") - @patch("evaluation.jax_kernel_evaluator.subprocess.run") - def test_evaluate_local_crash( - self, mock_subproc, mock_tempfile, mock_shutil, mock_load_task - ): - evaluator = JAXKernelEvaluator(local=True) - mock_load_task.return_value = KernelTask( - task_id="test_task", description="", input_gen_code="" - ) - mock_tempfile.mkdtemp.return_value = "/fake/dir" - mock_subproc.side_effect = subprocess.CalledProcessError( - returncode=-11, cmd="python3 harness.py", stderr="Segmentation fault" - ) - - with ( - patch("builtins.open", mock_open()), - patch.object(evaluator, "_build_task_json"), - patch.object(evaluator, "_build_harness_code"), - ): - result = evaluator._evaluate_local("ref.py", "opt.py", "task.yaml") - - self.assertIn("Command failed with exit code -11", result.error_trace) - - def test_evaluate_local_missing_task_yaml(self): - evaluator = JAXKernelEvaluator(local=True) - - # Test with adapt but missing task output - with patch.object(evaluator, "_adapt_inputs") as mock_adapt: - mock_adapt.return_value = ("ref.py", "opt.py", None) - with self.assertRaises(ValueError) as context: - evaluator._evaluate_local( - "ref.py", "opt.py", task_yaml_path=None, adapt=["reference_code"] - ) - self.assertIn("task_yaml_path is required", str(context.exception)) - - # Test without adapt - with self.assertRaises(ValueError) as context: - evaluator._evaluate_local("ref.py", "opt.py", task_yaml_path=None) - self.assertIn("task_yaml_path is required", str(context.exception)) - - @patch("evaluation.jax_kernel_evaluator.load_kernel_task_from_yaml") - @patch("evaluation.jax_kernel_evaluator.tempfile") - @patch("evaluation.jax_kernel_evaluator.os") - @patch("evaluation.jax_kernel_evaluator.print_eval_result") - @patch("evaluation.jax_kernel_evaluator.TPUVMClient") - def test_evaluate_remote_success( - self, mock_client_cls, mock_print, mock_os, mock_tempfile, mock_load_task - ): - mock_client_instance = mock_client_cls.return_value - mock_client_instance.tpu_name = "test-tpu" - - mock_cat_result = MagicMock() - mock_cat_result.stdout = '{"compiled_successfully": true, "numerically_correct": true, "reference_time_ms": 15.0, "optimized_time_ms": 10.0, "xprof_reference_time_ms": 12.0, "xprof_optimized_time_ms": 8.0}' - mock_client_instance.execute_ssh_command.side_effect = [ - MagicMock(), # mkdir - MagicMock(), # run script - mock_cat_result, # cat result.json - MagicMock(), # rm -rf (cleanup) - ] - - evaluator = JAXKernelEvaluator( - local=False, tpu_name="t", project="p", zone="z" - ) - evaluator.client = mock_client_instance - - mock_load_task.return_value = KernelTask( - task_id="test_task", description="", input_gen_code="" - ) - mock_tempfile.mkstemp.side_effect = [ - (1, "/tmp/harness.py"), - (2, "/tmp/task.json"), - ] - mock_os.path.exists.return_value = True - - with ( - patch.object(evaluator, "_build_task_json"), - patch.object(evaluator, "_build_harness_code"), - ): - result = evaluator._evaluate_remote("ref.py", "opt.py", "task.yaml") - - self.assertTrue(result.compiled_successfully) - self.assertTrue(result.numerically_correct) - self.assertEqual(result.speedup, 1.5) - self.assertEqual(result.xprof_reference_time_ms, 12.0) - self.assertEqual(result.xprof_optimized_time_ms, 8.0) - self.assertEqual(mock_client_instance.execute_ssh_command.call_count, 4) - mock_client_instance.upload_file.assert_called() - - @patch("evaluation.jax_kernel_evaluator.load_kernel_task_from_yaml") - @patch("evaluation.jax_kernel_evaluator.tempfile") - @patch("evaluation.jax_kernel_evaluator.os") - @patch("evaluation.jax_kernel_evaluator.print_eval_result") - @patch("evaluation.jax_kernel_evaluator.TPUVMClient") - def test_evaluate_remote_timeout( - self, mock_client_cls, mock_print, mock_os, mock_tempfile, mock_load_task - ): - mock_client_instance = mock_client_cls.return_value - mock_client_instance.tpu_name = "test-tpu" - - def side_effect(*args, **kwargs): - if "harness.py" in args[0] and "pkill" not in args[0]: - raise subprocess.TimeoutExpired(cmd=args[0], timeout=300) - return MagicMock() - - mock_client_instance.execute_ssh_command.side_effect = side_effect - - evaluator = JAXKernelEvaluator( - local=False, tpu_name="t", project="p", zone="z" - ) - evaluator.client = mock_client_instance - - mock_load_task.return_value = KernelTask( - task_id="test_task", description="", input_gen_code="" - ) - mock_tempfile.mkstemp.side_effect = [(1, "/tmp/h.py"), (2, "/tmp/t.json")] - - with ( - patch.object(evaluator, "_build_task_json"), - patch.object(evaluator, "_build_harness_code"), - patch("evaluation.jax_kernel_evaluator.time.sleep"), - ): # Mock sleep to avoid delay - result = evaluator._evaluate_remote("ref.py", "opt.py", "task.yaml") - - self.assertIn("timed out", result.error_trace) - # Verify pkill was called to cleanup the runaway process - pkill_calls = [ - call - for call in mock_client_instance.execute_ssh_command.call_args_list - if "pkill" in call.args[0] - ] - self.assertTrue(len(pkill_calls) > 0) - - @patch("evaluation.jax_kernel_evaluator.load_kernel_task_from_yaml") - @patch("evaluation.jax_kernel_evaluator.tempfile") - @patch("evaluation.jax_kernel_evaluator.os") - @patch("evaluation.jax_kernel_evaluator.print_eval_result") - @patch("evaluation.jax_kernel_evaluator.TPUVMClient") - def test_evaluate_remote_runtime_error( - self, mock_client_cls, mock_print, mock_os, mock_tempfile, mock_load_task - ): - mock_client_instance = mock_client_cls.return_value - - def side_effect(*args, **kwargs): - if "harness.py" in args[0]: - raise RuntimeError("Segmentation fault (core dumped)") - return MagicMock() - - mock_client_instance.execute_ssh_command.side_effect = side_effect - evaluator = JAXKernelEvaluator( - local=False, tpu_name="t", project="p", zone="z" - ) - evaluator.client = mock_client_instance - mock_load_task.return_value = KernelTask( - task_id="test_task", description="", input_gen_code="" - ) - mock_tempfile.mkstemp.side_effect = [(1, "/tmp/h.py"), (2, "/tmp/t.json")] - - with ( - patch.object(evaluator, "_build_task_json"), - patch.object(evaluator, "_build_harness_code"), - ): - result = evaluator._evaluate_remote("ref.py", "opt.py", "task.yaml") - - self.assertIn("Segmentation fault", result.error_trace) diff --git a/MaxKernel/evaluation/remote_client/tpu_client_test.py b/MaxKernel/evaluation/remote_client/tpu_client_test.py deleted file mode 100644 index 8e620a7..0000000 --- a/MaxKernel/evaluation/remote_client/tpu_client_test.py +++ /dev/null @@ -1,148 +0,0 @@ -import subprocess -import unittest -from unittest.mock import MagicMock, patch - -from google.auth.exceptions import DefaultCredentialsError - -from evaluation.remote_client.tpu_client import ( - TPUVMClient, - run_script_on_tpu_vm, -) - - -class TestTPUVMClient(unittest.TestCase): - def setUp(self): - self.auth_patcher = patch("remote_client.tpu_client.google.auth.default") - self.mock_auth = self.auth_patcher.start() - self.addCleanup(self.auth_patcher.stop) - self.mock_auth.return_value = (MagicMock(), "test-proj") - - def test_authenticate_success(self): - self.mock_auth.return_value = (MagicMock(), "default-project") - client = TPUVMClient(zone="us-east1-d", tpu_name="my-tpu") - self.assertEqual(client.project, "default-project") - self.assertEqual(client.zone, "us-east1-d") - self.assertEqual(client.tpu_name, "my-tpu") - - def test_authenticate_provided_project(self): - self.mock_auth.return_value = (MagicMock(), "default-project") - client = TPUVMClient( - project="my-project", zone="us-east1-d", tpu_name="my-tpu" - ) - self.assertEqual(client.project, "my-project") - - def test_authenticate_failure(self): - self.mock_auth.side_effect = DefaultCredentialsError("No creds") - with self.assertRaises(DefaultCredentialsError): - TPUVMClient() - - @patch("evaluation.remote_client.tpu_client.subprocess.run") - def test_run_gcloud_success(self, mock_subprocess_run): - client = TPUVMClient(zone="test-zone", tpu_name="test-tpu") - - mock_result = MagicMock() - mock_subprocess_run.return_value = mock_result - - res = client._run_gcloud(["arg1", "arg2"]) - - self.assertEqual(res, mock_result) - mock_subprocess_run.assert_called_once_with( - [ - "gcloud", - "compute", - "tpus", - "tpu-vm", - "arg1", - "arg2", - "--zone=test-zone", - "--project=test-proj", - ], - check=True, - capture_output=True, - text=True, - timeout=None, - ) - - @patch("evaluation.remote_client.tpu_client.subprocess.run") - def test_run_gcloud_failure(self, mock_subprocess_run): - client = TPUVMClient(zone="test-zone", tpu_name="test-tpu") - - mock_subprocess_run.side_effect = subprocess.CalledProcessError( - returncode=1, cmd="cmd", stderr="error message" - ) - with self.assertRaisesRegex( - RuntimeError, "Failed to execute gcloud command" - ): - client._run_gcloud(["arg"]) - - @patch("evaluation.remote_client.tpu_client.os.path.exists") - @patch.object(TPUVMClient, "_run_gcloud") - def test_upload_file_success(self, mock_run_gcloud, mock_exists): - mock_exists.return_value = True - client = TPUVMClient(zone="test-zone", tpu_name="test-tpu") - - client.upload_file("local_script.py", "remote_script.py") - mock_run_gcloud.assert_called_once_with( - ["scp", "local_script.py", "test-tpu:remote_script.py"] - ) - - @patch("evaluation.remote_client.tpu_client.os.path.exists") - def test_upload_file_not_found(self, mock_exists): - mock_exists.return_value = False - client = TPUVMClient(zone="test-zone", tpu_name="test-tpu") - - with self.assertRaises(FileNotFoundError): - client.upload_file("missing.py", "remote.py") - - @patch.object(TPUVMClient, "_run_gcloud") - def test_execute_ssh_command(self, mock_run_gcloud): - client = TPUVMClient(zone="test-zone", tpu_name="test-tpu") - - client.execute_ssh_command("ls -la", timeout=15) - mock_run_gcloud.assert_called_once_with( - ["ssh", "test-tpu", "--command=ls -la"], timeout=15 - ) - - @patch.object(TPUVMClient, "execute_ssh_command") - def test_delete_file(self, mock_ssh): - client = TPUVMClient(zone="test-zone", tpu_name="test-tpu") - - client.delete_file("remote_script.py") - mock_ssh.assert_called_once_with("rm -f remote_script.py") - - def test_quote_path(self): - self.assertEqual(TPUVMClient.quote_path("/path/to/file"), "/path/to/file") - self.assertEqual(TPUVMClient.quote_path("~/my_file.py"), "$HOME/my_file.py") - self.assertEqual(TPUVMClient.quote_path("my file.py"), "'my file.py'") - - -class TestRunScriptOnTPUVM(unittest.TestCase): - @patch("evaluation.remote_client.tpu_client.TPUVMClient") - def test_run_script_success(self, mock_client_class): - mock_client_class.quote_path.side_effect = lambda x: x - mock_client = mock_client_class.return_value - mock_result = MagicMock() - mock_result.stdout = "output" - mock_result.stderr = "" - mock_client.execute_ssh_command.return_value = mock_result - - run_script_on_tpu_vm( - local_script_path="local.py", - tpu_name="test-tpu", - zone="test-zone", - venv_path="/path/to/venv", - script_args=["arg1", "arg2"], - cleanup_script=True, - ) - - mock_client.upload_file.assert_called_once_with( - "local.py", "uploaded_local.py" - ) - mock_client.execute_ssh_command.assert_called_once() - - cmd = mock_client.execute_ssh_command.call_args[0][0] - self.assertIn("source /path/to/venv/bin/activate", cmd) - self.assertIn("python3 uploaded_local.py", cmd) - self.assertIn("arg1 arg2", cmd) - - mock_client.delete_file.assert_called_once_with("uploaded_local.py")