|
| 1 | +"""Tests for automation.generators.plot_generator module.""" |
| 2 | + |
| 3 | +from unittest.mock import MagicMock, patch |
| 4 | + |
| 5 | +import pytest |
| 6 | + |
| 7 | +from automation.generators.plot_generator import extract_and_validate_code, retry_with_backoff |
| 8 | + |
| 9 | + |
| 10 | +class TestExtractAndValidateCode: |
| 11 | + """Tests for extract_and_validate_code function.""" |
| 12 | + |
| 13 | + def test_extract_plain_code(self): |
| 14 | + response = """import matplotlib.pyplot as plt |
| 15 | +plt.plot([1, 2, 3]) |
| 16 | +plt.savefig('plot.png')""" |
| 17 | + |
| 18 | + result = extract_and_validate_code(response) |
| 19 | + |
| 20 | + assert "import matplotlib" in result |
| 21 | + assert "plt.plot" in result |
| 22 | + assert "plt.savefig" in result |
| 23 | + |
| 24 | + def test_extract_from_markdown_python(self): |
| 25 | + response = """Here is the implementation: |
| 26 | +
|
| 27 | +```python |
| 28 | +import matplotlib.pyplot as plt |
| 29 | +import numpy as np |
| 30 | +
|
| 31 | +np.random.seed(42) |
| 32 | +x = np.random.randn(100) |
| 33 | +plt.scatter(x, x * 0.5) |
| 34 | +plt.savefig('plot.png') |
| 35 | +``` |
| 36 | +
|
| 37 | +This creates a simple scatter plot.""" |
| 38 | + |
| 39 | + result = extract_and_validate_code(response) |
| 40 | + |
| 41 | + assert "import matplotlib" in result |
| 42 | + assert "import numpy" in result |
| 43 | + assert "np.random.seed(42)" in result |
| 44 | + assert "Here is the implementation" not in result |
| 45 | + assert "This creates" not in result |
| 46 | + |
| 47 | + def test_extract_from_generic_markdown(self): |
| 48 | + response = """``` |
| 49 | +import numpy as np |
| 50 | +x = np.array([1, 2, 3]) |
| 51 | +print(x) |
| 52 | +```""" |
| 53 | + |
| 54 | + result = extract_and_validate_code(response) |
| 55 | + |
| 56 | + assert "import numpy" in result |
| 57 | + assert "np.array" in result |
| 58 | + |
| 59 | + def test_extract_code_with_multiple_blocks(self): |
| 60 | + """Should extract from first python block.""" |
| 61 | + response = """Here's an example: |
| 62 | +
|
| 63 | +```python |
| 64 | +import matplotlib.pyplot as plt |
| 65 | +plt.plot([1, 2, 3]) |
| 66 | +``` |
| 67 | +
|
| 68 | +And here's another: |
| 69 | +
|
| 70 | +```python |
| 71 | +print("second block") |
| 72 | +``` |
| 73 | +""" |
| 74 | + |
| 75 | + result = extract_and_validate_code(response) |
| 76 | + |
| 77 | + assert "import matplotlib" in result |
| 78 | + assert "plt.plot" in result |
| 79 | + # Should only get first block |
| 80 | + assert "second block" not in result |
| 81 | + |
| 82 | + def test_empty_code_raises_value_error(self): |
| 83 | + with pytest.raises(ValueError, match="No code could be extracted"): |
| 84 | + extract_and_validate_code("") |
| 85 | + |
| 86 | + def test_whitespace_only_raises_value_error(self): |
| 87 | + with pytest.raises(ValueError, match="No code could be extracted"): |
| 88 | + extract_and_validate_code(" \n\n ") |
| 89 | + |
| 90 | + def test_empty_code_block_raises_value_error(self): |
| 91 | + response = """```python |
| 92 | +```""" |
| 93 | + |
| 94 | + with pytest.raises(ValueError, match="No code could be extracted"): |
| 95 | + extract_and_validate_code(response) |
| 96 | + |
| 97 | + def test_syntax_error_raises_value_error(self): |
| 98 | + response = """```python |
| 99 | +def broken( |
| 100 | + print("missing closing paren" |
| 101 | +```""" |
| 102 | + |
| 103 | + with pytest.raises(ValueError, match="syntax errors"): |
| 104 | + extract_and_validate_code(response) |
| 105 | + |
| 106 | + def test_indentation_error_raises(self): |
| 107 | + response = """```python |
| 108 | +def foo(): |
| 109 | +print("bad indent") |
| 110 | +```""" |
| 111 | + |
| 112 | + with pytest.raises(ValueError, match="syntax errors"): |
| 113 | + extract_and_validate_code(response) |
| 114 | + |
| 115 | + def test_valid_complex_code(self): |
| 116 | + response = """```python |
| 117 | +import matplotlib.pyplot as plt |
| 118 | +import numpy as np |
| 119 | +from typing import Optional |
| 120 | +
|
| 121 | +def create_plot(title: Optional[str] = None) -> None: |
| 122 | + np.random.seed(42) |
| 123 | + x = np.random.randn(100) |
| 124 | + y = x * 0.8 + np.random.randn(100) * 0.5 |
| 125 | +
|
| 126 | + fig, ax = plt.subplots(figsize=(16, 9)) |
| 127 | + ax.scatter(x, y, alpha=0.7) |
| 128 | +
|
| 129 | + if title: |
| 130 | + ax.set_title(title) |
| 131 | +
|
| 132 | + plt.savefig('plot.png', dpi=300) |
| 133 | +
|
| 134 | +if __name__ == '__main__': |
| 135 | + create_plot('Scatter Plot') |
| 136 | +```""" |
| 137 | + |
| 138 | + result = extract_and_validate_code(response) |
| 139 | + |
| 140 | + assert "np.random.seed(42)" in result |
| 141 | + assert "figsize=(16, 9)" in result |
| 142 | + assert "def create_plot" in result |
| 143 | + assert "Optional[str]" in result |
| 144 | + |
| 145 | + def test_code_with_comments_and_docstrings(self): |
| 146 | + response = '''```python |
| 147 | +"""Module docstring.""" |
| 148 | +
|
| 149 | +import matplotlib.pyplot as plt |
| 150 | +
|
| 151 | +# Create a simple plot |
| 152 | +def plot_data(): |
| 153 | + """Create and save a plot.""" |
| 154 | + plt.plot([1, 2, 3]) # inline comment |
| 155 | + plt.savefig("output.png") |
| 156 | +```''' |
| 157 | + |
| 158 | + result = extract_and_validate_code(response) |
| 159 | + |
| 160 | + assert '"""Module docstring."""' in result |
| 161 | + assert "# Create a simple plot" in result |
| 162 | + assert "# inline comment" in result |
| 163 | + |
| 164 | + def test_code_with_f_strings(self): |
| 165 | + response = """```python |
| 166 | +name = "test" |
| 167 | +value = 42 |
| 168 | +print(f"Name: {name}, Value: {value}") |
| 169 | +```""" |
| 170 | + |
| 171 | + result = extract_and_validate_code(response) |
| 172 | + |
| 173 | + assert 'f"Name: {name}' in result |
| 174 | + |
| 175 | + def test_preserves_newlines_in_code(self): |
| 176 | + response = """```python |
| 177 | +import matplotlib.pyplot as plt |
| 178 | +
|
| 179 | +
|
| 180 | +def func1(): |
| 181 | + pass |
| 182 | +
|
| 183 | +
|
| 184 | +def func2(): |
| 185 | + pass |
| 186 | +```""" |
| 187 | + |
| 188 | + result = extract_and_validate_code(response) |
| 189 | + |
| 190 | + # Should preserve blank lines |
| 191 | + assert "\n\n" in result |
| 192 | + |
| 193 | + |
| 194 | +class TestRetryWithBackoff: |
| 195 | + """Tests for retry_with_backoff function.""" |
| 196 | + |
| 197 | + def test_success_on_first_try(self): |
| 198 | + func = MagicMock(return_value="success") |
| 199 | + |
| 200 | + result = retry_with_backoff(func, max_retries=3) |
| 201 | + |
| 202 | + assert result == "success" |
| 203 | + assert func.call_count == 1 |
| 204 | + |
| 205 | + def test_retry_on_rate_limit_error(self): |
| 206 | + from anthropic import RateLimitError |
| 207 | + |
| 208 | + mock_response = MagicMock() |
| 209 | + mock_response.status_code = 429 |
| 210 | + |
| 211 | + func = MagicMock( |
| 212 | + side_effect=[RateLimitError(message="rate limited", response=mock_response, body={}), "success"] |
| 213 | + ) |
| 214 | + |
| 215 | + with patch("time.sleep"): # Skip actual sleep |
| 216 | + result = retry_with_backoff(func, max_retries=3, initial_delay=0.01) |
| 217 | + |
| 218 | + assert result == "success" |
| 219 | + assert func.call_count == 2 |
| 220 | + |
| 221 | + def test_retry_on_connection_error(self): |
| 222 | + from anthropic import APIConnectionError |
| 223 | + |
| 224 | + mock_request = MagicMock() |
| 225 | + |
| 226 | + func = MagicMock( |
| 227 | + side_effect=[ |
| 228 | + APIConnectionError(message="connection failed", request=mock_request), |
| 229 | + APIConnectionError(message="connection failed again", request=mock_request), |
| 230 | + "success", |
| 231 | + ] |
| 232 | + ) |
| 233 | + |
| 234 | + with patch("time.sleep"): |
| 235 | + result = retry_with_backoff(func, max_retries=3, initial_delay=0.01) |
| 236 | + |
| 237 | + assert result == "success" |
| 238 | + assert func.call_count == 3 |
| 239 | + |
| 240 | + def test_max_retries_exceeded_raises(self): |
| 241 | + from anthropic import RateLimitError |
| 242 | + |
| 243 | + mock_response = MagicMock() |
| 244 | + mock_response.status_code = 429 |
| 245 | + |
| 246 | + func = MagicMock(side_effect=RateLimitError(message="rate limited", response=mock_response, body={})) |
| 247 | + |
| 248 | + with patch("time.sleep"): |
| 249 | + with pytest.raises(RateLimitError): |
| 250 | + retry_with_backoff(func, max_retries=2, initial_delay=0.01) |
| 251 | + |
| 252 | + # Initial attempt + 2 retries = 3 calls |
| 253 | + assert func.call_count == 3 |
| 254 | + |
| 255 | + def test_no_retry_on_generic_api_error(self): |
| 256 | + """API errors (non-rate-limit, non-connection) should not retry.""" |
| 257 | + from anthropic import APIError |
| 258 | + |
| 259 | + mock_request = MagicMock() |
| 260 | + |
| 261 | + func = MagicMock(side_effect=APIError(message="bad request", request=mock_request, body={})) |
| 262 | + |
| 263 | + with pytest.raises(APIError, match="bad request"): |
| 264 | + retry_with_backoff(func, max_retries=3) |
| 265 | + |
| 266 | + # Should not retry |
| 267 | + assert func.call_count == 1 |
| 268 | + |
| 269 | + def test_exponential_backoff_delays(self): |
| 270 | + from anthropic import RateLimitError |
| 271 | + |
| 272 | + mock_response = MagicMock() |
| 273 | + mock_response.status_code = 429 |
| 274 | + |
| 275 | + func = MagicMock( |
| 276 | + side_effect=[ |
| 277 | + RateLimitError(message="rate limited", response=mock_response, body={}), |
| 278 | + RateLimitError(message="rate limited", response=mock_response, body={}), |
| 279 | + "success", |
| 280 | + ] |
| 281 | + ) |
| 282 | + |
| 283 | + sleep_calls = [] |
| 284 | + with patch("time.sleep", side_effect=lambda x: sleep_calls.append(x)): |
| 285 | + result = retry_with_backoff(func, max_retries=3, initial_delay=1.0, backoff_factor=2.0) |
| 286 | + |
| 287 | + assert result == "success" |
| 288 | + # First retry: 1.0s, Second retry: 2.0s (1.0 * 2.0) |
| 289 | + assert sleep_calls == [1.0, 2.0] |
| 290 | + |
| 291 | + def test_returns_result_type(self): |
| 292 | + """Test that return type matches function return type.""" |
| 293 | + func = MagicMock(return_value={"key": "value", "count": 42}) |
| 294 | + |
| 295 | + result = retry_with_backoff(func) |
| 296 | + |
| 297 | + assert result == {"key": "value", "count": 42} |
| 298 | + assert isinstance(result, dict) |
| 299 | + |
| 300 | + def test_custom_max_retries(self): |
| 301 | + from anthropic import APIConnectionError |
| 302 | + |
| 303 | + mock_request = MagicMock() |
| 304 | + |
| 305 | + func = MagicMock(side_effect=APIConnectionError(message="connection failed", request=mock_request)) |
| 306 | + |
| 307 | + with patch("time.sleep"): |
| 308 | + with pytest.raises(APIConnectionError): |
| 309 | + retry_with_backoff(func, max_retries=5, initial_delay=0.01) |
| 310 | + |
| 311 | + # Initial attempt + 5 retries = 6 calls |
| 312 | + assert func.call_count == 6 |
| 313 | + |
| 314 | + def test_zero_retries(self): |
| 315 | + from anthropic import RateLimitError |
| 316 | + |
| 317 | + mock_response = MagicMock() |
| 318 | + mock_response.status_code = 429 |
| 319 | + |
| 320 | + func = MagicMock(side_effect=RateLimitError(message="rate limited", response=mock_response, body={})) |
| 321 | + |
| 322 | + with pytest.raises(RateLimitError): |
| 323 | + retry_with_backoff(func, max_retries=0) |
| 324 | + |
| 325 | + # Only initial attempt |
| 326 | + assert func.call_count == 1 |
0 commit comments