Skip to content

Commit d87f113

Browse files
committed
Add tests for GPU parsing, concurrency limits, network config, health
GPU: valid types, invalid type, case-insensitive. Concurrency: 429 on limit, counter decrements on error. Network: CIDR parsing, empty CIDR. Health: endpoint returns ok.
1 parent 12e3973 commit d87f113

1 file changed

Lines changed: 113 additions & 0 deletions

File tree

tests/test_app.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,3 +425,116 @@ def test_sanitize_error_message_empty():
425425

426426
assert _sanitize_error_message("") == "[empty response]"
427427
assert _sanitize_error_message(None) == "[empty response]"
428+
429+
430+
# =============================================================================
431+
# GPU / network / concurrency / health tests
432+
# =============================================================================
433+
434+
435+
def test_get_gpu_config_valid_types():
436+
from app import _get_gpu_config
437+
438+
for gpu_key in ["t4", "l4", "a100", "a100-80gb", "h100"]:
439+
with patch("app.modal") as mock_modal:
440+
mock_modal.gpu.T4.return_value = "mock_gpu"
441+
mock_modal.gpu.L4.return_value = "mock_gpu"
442+
mock_modal.gpu.A100.return_value = "mock_gpu"
443+
mock_modal.gpu.A100_80GB.return_value = "mock_gpu"
444+
mock_modal.gpu.H100.return_value = "mock_gpu"
445+
result = _get_gpu_config(gpu_key)
446+
assert result is not None
447+
448+
449+
def test_get_gpu_config_invalid_type():
450+
from app import _get_gpu_config
451+
452+
with patch("app.modal"):
453+
result = _get_gpu_config("nonexistent-gpu")
454+
assert result is None
455+
456+
457+
def test_get_gpu_config_case_insensitive():
458+
from app import _get_gpu_config
459+
460+
with patch("app.modal") as mock_modal:
461+
mock_modal.gpu.T4.return_value = "mock_gpu"
462+
with patch.dict("app.GPU_LABEL_TO_ATTR", {"T4": "T4"}, clear=False):
463+
result = _get_gpu_config("T4")
464+
assert result is not None
465+
466+
467+
@pytest.mark.asyncio
468+
async def test_concurrency_limit_returns_429():
469+
from fastapi import HTTPException
470+
471+
import app as app_module
472+
from app import github_webhook
473+
474+
original = app_module.MAX_CONCURRENT_PER_REPO
475+
app_module.MAX_CONCURRENT_PER_REPO = 1
476+
app_module._concurrent_jobs["owner/repo"] = 1
477+
478+
body = _make_webhook_body()
479+
request = _make_request(body)
480+
481+
try:
482+
with pytest.raises(HTTPException) as exc_info:
483+
await github_webhook(request)
484+
assert exc_info.value.status_code == 429
485+
finally:
486+
app_module.MAX_CONCURRENT_PER_REPO = original
487+
app_module._concurrent_jobs.clear()
488+
489+
490+
@pytest.mark.asyncio
491+
async def test_concurrency_counter_decrements():
492+
import httpx
493+
494+
import app as app_module
495+
from app import github_webhook
496+
497+
original = app_module.MAX_CONCURRENT_PER_REPO
498+
app_module.MAX_CONCURRENT_PER_REPO = 5
499+
app_module._concurrent_jobs.clear()
500+
501+
body = _make_webhook_body()
502+
request = _make_request(body)
503+
504+
req = httpx.Request("POST", "http://test")
505+
mock_response = httpx.Response(200, json={"encoded_jit_config": "jit-config"}, request=req)
506+
507+
try:
508+
with patch("app._call_github_api", new=AsyncMock(return_value=mock_response)):
509+
with patch("app.modal.Sandbox.create", side_effect=Exception("sandbox failure")):
510+
with pytest.raises(Exception):
511+
await github_webhook(request)
512+
finally:
513+
assert app_module._concurrent_jobs.get("owner/repo", 0) == 0
514+
app_module.MAX_CONCURRENT_PER_REPO = original
515+
516+
517+
def test_network_config_defaults():
518+
import app
519+
520+
assert app.BLOCK_NETWORK is False or app.BLOCK_NETWORK is True
521+
522+
523+
def test_allowed_cidrs_parsing():
524+
cidrs_str = "10.0.0.0/8, 172.16.0.0/12"
525+
cidrs = [c.strip() for c in cidrs_str.split(",") if c.strip()] if cidrs_str else None
526+
assert cidrs == ["10.0.0.0/8", "172.16.0.0/12"]
527+
528+
529+
def test_allowed_cidrs_empty():
530+
cidrs_str = ""
531+
cidrs = [c.strip() for c in cidrs_str.split(",") if c.strip()] if cidrs_str else None
532+
assert cidrs is None
533+
534+
535+
@pytest.mark.asyncio
536+
async def test_health_endpoint():
537+
from app import health
538+
539+
result = await health(MagicMock())
540+
assert result == {"status": "ok"}

0 commit comments

Comments
 (0)