-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathcode_adapter_test.py
More file actions
151 lines (128 loc) · 5.05 KB
/
code_adapter_test.py
File metadata and controls
151 lines (128 loc) · 5.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import unittest
from unittest.mock import MagicMock, patch
from google import genai
from evaluation.code_adapter import code_adapter
from evaluation.custom_types.kernel_task import KernelTask
class TestCodeAdapter(unittest.TestCase):
    """Unit tests for ``code_adapter.CodeAdapter`` driven by a mocked client."""

    # Well-formed adapted code containing the three section markers
    # (# Imports / # Initialization / # Computation) the adapter requires.
    # Shared by the success-path tests so the expectation lives in one place.
    _ADAPTED_CODE = (
        "# Imports\n"
        "import jax\n"
        "# Initialization\n"
        "x = 1\n"
        "# Computation\n"
        "def comp(): pass"
    )

    def setUp(self):
        """Build an adapter around a mock specced like ``genai.Client``."""
        self.mock_client = MagicMock(spec=genai.Client)
        # Use 2 retries to make the test run faster
        self.adapter = code_adapter.CodeAdapter(
            client=self.mock_client, max_retries=2
        )

    @patch.object(code_adapter.CodeAdapter, "_get_adapt_reference_prompt")
    def test_adapt_reference_success(self, mock_prompt):
        """Reference adaptation strips a Markdown python fence from the reply."""
        mock_prompt.return_value = "mock reference prompt"
        mock_response = MagicMock()
        # Test that it properly strips python markdown formatting
        mock_response.text = f"```python\n{self._ADAPTED_CODE}\n```"
        self.mock_client.models.generate_content.return_value = mock_response

        result = self.adapter.adapt("original code")

        self.assertEqual(result, self._ADAPTED_CODE)
        # The default (non-optimized) path must build the reference prompt;
        # without this check a wrong branch could still pass the test.
        mock_prompt.assert_called_once()
        self.mock_client.models.generate_content.assert_called_once()

    @patch.object(code_adapter.CodeAdapter, "_get_adapt_optimized_prompt")
    def test_adapt_optimized_success(self, mock_prompt):
        """Optimized adaptation returns an unfenced reply unchanged."""
        mock_prompt.return_value = "mock optimized prompt"
        mock_response = MagicMock()
        # Test with no markdown backticks
        mock_response.text = self._ADAPTED_CODE
        self.mock_client.models.generate_content.return_value = mock_response

        result = self.adapter.adapt(
            "original code",
            adapt_optimized=True,
            get_inputs_code="def get_inputs(): pass",
        )

        self.assertEqual(result, self._ADAPTED_CODE)
        # adapt_optimized=True must route through the optimized prompt builder.
        mock_prompt.assert_called_once()
        self.mock_client.models.generate_content.assert_called_once()

    def test_adapt_optimized_missing_get_inputs(self):
        """Optimized adaptation without get_inputs_code raises ValueError."""
        with self.assertRaisesRegex(ValueError, "get_inputs_code must be provided"):
            self.adapter.adapt("original code", adapt_optimized=True)
        # Argument validation should fail fast, before any model request.
        self.mock_client.models.generate_content.assert_not_called()

    @patch.object(code_adapter.time, "sleep")
    @patch.object(code_adapter.CodeAdapter, "_get_adapt_reference_prompt")
    def test_adapt_retries_and_fails_on_missing_sections(
        self, mock_prompt, mock_sleep
    ):
        """Replies lacking the section markers are retried, then RuntimeError."""
        mock_prompt.return_value = "mock prompt"
        mock_response = MagicMock()
        # Missing the required # Imports, # Initialization, # Computation sections
        mock_response.text = "```python\ndef bad_format(): pass\n```"
        self.mock_client.models.generate_content.return_value = mock_response

        with self.assertRaisesRegex(RuntimeError, "Failed to refactor code"):
            self.adapter.adapt("original code")

        # max_retries = 2, so it should attempt 2 times
        self.assertEqual(self.mock_client.models.generate_content.call_count, 2)
        # time.sleep is patched out, so the retry loop does not actually wait.
        mock_sleep.assert_called()

    @patch.object(code_adapter.time, "sleep")
    @patch.object(code_adapter.CodeAdapter, "_get_adapt_reference_prompt")
    def test_adapt_retries_and_fails_on_exception(self, mock_prompt, mock_sleep):
        """Client exceptions are retried up to max_retries, then RuntimeError."""
        mock_prompt.return_value = "mock prompt"
        self.mock_client.models.generate_content.side_effect = Exception(
            "API Error"
        )

        with self.assertRaisesRegex(RuntimeError, "Failed to refactor code"):
            self.adapter.adapt("original code")

        self.assertEqual(self.mock_client.models.generate_content.call_count, 2)
        mock_sleep.assert_called()

    def test_extract_input_gen_code_success(self):
        """Imports and initialization are folded into the get_inputs body."""
        sample_code = (
            "# Imports\n"
            "import jax\n"
            "import jax.numpy as jnp\n"
            "# Initialization\n"
            "BATCH = 8\n"
            "\n"
            "def get_inputs():\n"
            " x = jnp.zeros(BATCH)\n"
            " return [x], []\n"
            "# Computation\n"
            "def computation(x):\n"
            " return x\n"
        )
        expected_extracted = (
            "def get_inputs():\n"
            " import jax\n"
            " import jax.numpy as jnp\n"
            "\n"
            " BATCH = 8\n"
            "\n"
            " x = jnp.zeros(BATCH)\n"
            " return [x], []"
        )

        result = self.adapter._extract_input_gen_code(sample_code)

        self.assertEqual(result, expected_extracted)

    def test_extract_input_gen_code_missing_get_inputs(self):
        """Code without a get_inputs function yields an empty string."""
        sample_code = (
            "# Imports\n"
            "import jax\n"
            "# Initialization\n"
            "BATCH = 8\n"
            "# Computation\n"
            "def computation(x):\n"
            " return x\n"
        )

        result = self.adapter._extract_input_gen_code(sample_code)

        self.assertEqual(result, "")

    def test_generate_kernel_task(self):
        """generate_kernel_task builds a KernelTask carrying the input generator."""
        sample_code = (
            "# Imports\n"
            "import jax\n"
            "# Initialization\n"
            "def get_inputs():\n"
            " return [], []\n"
            "# Computation\n"
            "def computation(): pass\n"
        )

        task = self.adapter.generate_kernel_task(
            "test_id", "test desc", sample_code
        )

        self.assertIsInstance(task, KernelTask)
        self.assertEqual(task.task_id, "test_id")
        self.assertEqual(task.description, "test desc")
        self.assertIn("def get_inputs():", task.input_gen_code)
        self.assertIn("import jax", task.input_gen_code)