Skip to content

Commit 7f9137a

Browse files
authored
Kernelguard pre-queue check addition (#473)
* Add KernelGuard integration for submission pre-checks
  - Introduced KernelGuard for validating submissions before processing.
  - Implemented error handling for rejected submissions in the backend.
  - Updated database methods to mark submissions as hacked when flagged.
  - Enhanced tests to cover new KernelGuard functionality and error scenarios.
  - Added a new kernelguard.py module for managing submission analysis and pre-checks.
* Update Python version requirement and enhance KernelGuard integration
  - Updated Python version requirement from 3.10 to 3.11 in pyproject.toml and uv.lock.
  - Added `kernelguard` dependency to manage submission pre-checks.
  - Enhanced error handling in submission processes to include KernelGuard rejection scenarios.
  - Implemented pre-check logic in the submission workflow to prevent blocked submissions from queuing.
  - Updated tests to reflect changes in submission handling and pre-check logic.
* ruff fix

---------

Co-authored-by: Sinatras <SinatrasC@users.noreply.github.com>
1 parent 2ed4719 commit 7f9137a

11 files changed

Lines changed: 437 additions & 330 deletions

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ name = "discord-cluster-manager"
77
version = "0.1.0"
88
description = "Discord bot for managing compute clusters and running kernel benchmarks"
99
readme = "README.md"
10-
requires-python = ">=3.10"
10+
requires-python = ">=3.11"
1111
dependencies = [
1212
"PyGithub",
1313
"aiohttp",
@@ -25,6 +25,7 @@ dependencies = [
2525
"jinja2",
2626
"huggingface-hub>=0.20",
2727
"pyarrow>=14.0",
28+
"kernelguard>=0.1.1",
2829
]
2930

3031
[project.optional-dependencies]

src/kernelbot/api/api_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from kernelbot.env import env
77
from libkernelbot.backend import KernelBackend
88
from libkernelbot.consts import SubmissionMode
9+
from libkernelbot.kernelguard import KernelGuardRejected
910
from libkernelbot.leaderboard_db import LeaderboardDB
1011
from libkernelbot.report import (
1112
Log,
@@ -18,6 +19,7 @@
1819
SubmissionRequest,
1920
prepare_submission,
2021
)
22+
from libkernelbot.utils import KernelBotError
2123

2224

2325
async def _handle_discord_oauth(code: str, redirect_uri: str) -> tuple[str, str]:
@@ -154,7 +156,12 @@ async def _run_submission(
154156
raise HTTPException(status_code=400, detail="Invalid GPU type")
155157

156158
reporter = MultiProgressReporterAPI()
157-
sub_id, results = await backend.submit_full(req, mode, reporter)
159+
try:
160+
sub_id, results = await backend.submit_full(req, mode, reporter)
161+
except KernelGuardRejected as e:
162+
raise HTTPException(status_code=400, detail=str(e)) from e
163+
except KernelBotError as e:
164+
raise HTTPException(status_code=getattr(e, "http_code", 500), detail=str(e)) from e
158165
return results, [rep.get_message() + "\n" + rep.long_report for rep in reporter.runs]
159166

160167

src/kernelbot/api/main.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from libkernelbot.background_submission_manager import BackgroundSubmissionManager
1616
from libkernelbot.consts import SubmissionMode
1717
from libkernelbot.db_types import IdentityType
18+
from libkernelbot.kernelguard import KernelGuardRejected, enforce_submission_precheck, should_precheck_submission
1819
from libkernelbot.leaderboard_db import LeaderboardDB, LeaderboardRankedEntry
1920
from libkernelbot.problem_sync import sync_problems
2021
from libkernelbot.submission import (
@@ -563,6 +564,18 @@ async def run_submission_async(
563564
if not req.gpus or len(req.gpus) != 1:
564565
raise HTTPException(status_code=400, detail="Invalid GPU type")
565566

567+
# run KernelGuard pre-check before enqueuing to avoid filling the queue with blocked submissions
568+
if should_precheck_submission(submission_mode_enum):
569+
try:
570+
await asyncio.wait_for(
571+
asyncio.to_thread(enforce_submission_precheck, req.code, req.file_name),
572+
timeout=5.0,
573+
)
574+
except asyncio.TimeoutError as e:
575+
raise HTTPException(status_code=504, detail="KernelGuard pre-check timed out") from e
576+
except KernelGuardRejected as e:
577+
raise HTTPException(status_code=400, detail=str(e)) from e
578+
566579
# put submission request to background manager to run in background
567580
sub_id, job_status_id = await enqueue_background_job(
568581
req, submission_mode_enum, backend_instance, background_submission_manager

src/libkernelbot/backend.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
from typing import Optional
66

77
from libkernelbot.consts import GPU, GPU_TO_SM, SubmissionMode, get_gpu_by_name, get_mode_category
8+
from libkernelbot.kernelguard import (
9+
KernelGuardRejected,
10+
enforce_submission_precheck,
11+
should_precheck_submission,
12+
)
813
from libkernelbot.launchers import Launcher
914
from libkernelbot.leaderboard_db import LeaderboardDB
1015
from libkernelbot.report import (
@@ -53,10 +58,11 @@ async def submit_full(
5358
mode: SubmissionMode,
5459
reporter: MultiProgressReporter,
5560
pre_sub_id: Optional[int] = None,
61+
skip_precheck: bool = False,
5662
):
5763
"""
5864
pre_sub_id is used to pass the submission id which is created beforehand.
59-
65+
skip_precheck skips the KernelGuard pre-check (use when the caller already ran it).
6066
"""
6167
if pre_sub_id is not None:
6268
sub_id = pre_sub_id
@@ -72,7 +78,29 @@ async def submit_full(
7278
mode_category=req.mode_category or get_mode_category(mode),
7379
)
7480
selected_gpus = [get_gpu_by_name(gpu) for gpu in req.gpus]
81+
submission_started = False
7582
try:
83+
if not skip_precheck and should_precheck_submission(mode):
84+
try:
85+
await asyncio.to_thread(enforce_submission_precheck, req.code, req.file_name)
86+
except KernelGuardRejected as exc:
87+
logger.error(
88+
"Submission %s rejected by precheck: file=%s, mode=%s, error=%s",
89+
sub_id, req.file_name, mode, str(exc)
90+
)
91+
with self.db as db:
92+
db.mark_submission_hacked(sub_id, error=str(exc))
93+
raise
94+
except Exception as exc:
95+
logger.error(
96+
"Submission %s precheck unavailable: file=%s, mode=%s, error=%s",
97+
sub_id, req.file_name, mode, str(exc)
98+
)
99+
with self.db as db:
100+
db.mark_submission_done(sub_id)
101+
raise
102+
103+
submission_started = True
76104
tasks = [
77105
self.submit_leaderboard(
78106
sub_id,
@@ -106,8 +134,9 @@ async def submit_full(
106134
)
107135
results = await asyncio.gather(*tasks)
108136
finally:
109-
with self.db as db:
110-
db.mark_submission_done(sub_id)
137+
if submission_started:
138+
with self.db as db:
139+
db.mark_submission_done(sub_id)
111140
return sub_id, results
112141

113142
async def submit_leaderboard( # noqa: C901

src/libkernelbot/background_submission_manager.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from libkernelbot.backend import KernelBackend
88
from libkernelbot.consts import SubmissionMode
9+
from libkernelbot.kernelguard import KernelGuardRejected
910
from libkernelbot.report import MultiProgressReporter, RunProgressReporter, RunResultReport
1011
from libkernelbot.submission import ProcessedSubmissionRequest
1112
from libkernelbot.utils import setup_logging
@@ -233,7 +234,7 @@ async def heartbeat():
233234
reporter = BackgroundSubmissionManagerReporter()
234235
await asyncio.wait_for(
235236
self.backend.submit_full(
236-
item.req, item.mode, reporter, sub_id
237+
item.req, item.mode, reporter, sub_id, skip_precheck=True
237238
),
238239
timeout=HARD_TIMEOUT_SEC,
239240
)
@@ -252,6 +253,22 @@ async def heartbeat():
252253
last_heartbeat=ts,
253254
error="hard timeout reached",
254255
)
256+
except KernelGuardRejected as e:
257+
ts = dt.datetime.now(dt.timezone.utc)
258+
logger.info("[Background Job] submission %s flagged as hacked", sub_id)
259+
try:
260+
with self.backend.db as db:
261+
db.upsert_submission_job_status(
262+
sub_id,
263+
status="hacked",
264+
last_heartbeat=ts,
265+
error=str(e),
266+
)
267+
except Exception:
268+
logger.error(
269+
"[Background Job] Failed to write hacked status for submission %s",
270+
sub_id,
271+
)
255272
except Exception as e:
256273
ts = dt.datetime.now(dt.timezone.utc)
257274
logger.error(

src/libkernelbot/kernelguard.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import json
2+
import os
3+
import shlex
4+
import shutil
5+
import subprocess
6+
from typing import Any
7+
8+
from libkernelbot.consts import SubmissionMode
9+
from libkernelbot.utils import KernelBotError, limit_length, setup_logging
10+
11+
# Module-level logger named after this module.
logger = setup_logging(__name__)

# Lower-cased environment-variable spellings that count as "enabled".
_TRUE_VALUES = {"1", "true", "yes", "on"}

# CLI timeout (seconds) used when KERNELGUARD_TIMEOUT_SEC is unset or not an integer.
_DEFAULT_TIMEOUT_SEC = 30

# Submission modes for which the KernelGuard pre-check is applied.
_GUARDED_MODES = frozenset(
    (
        SubmissionMode.BENCHMARK,
        SubmissionMode.PROFILE,
        SubmissionMode.LEADERBOARD,
        SubmissionMode.PRIVATE,
    )
)
23+
24+
25+
class KernelGuardRejected(KernelBotError):
    """Error raised when the KernelGuard pre-check decides to filter a submission.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, message: str, result: dict[str, Any]):
        """Store the rejection message and the analysis payload that caused it.

        Args:
            message: Human-readable rejection reason, forwarded to the base error.
            result: The JSON object returned by KernelGuard for this submission.
        """
        super().__init__(message)
        # Keep the full payload so callers can inspect what was matched.
        self.result = result
29+
30+
31+
def _env_enabled(name: str, default: bool = False) -> bool:
    """Interpret the environment variable *name* as a boolean flag.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is not set at all.

    Returns:
        ``default`` if the variable is unset; otherwise whether its stripped,
        lower-cased value is one of ``_TRUE_VALUES``.
    """
    value = os.environ.get(name)
    if value is not None:
        return value.strip().lower() in _TRUE_VALUES
    return default
36+
37+
38+
def should_precheck_submission(mode: SubmissionMode) -> bool:
    """Tell whether a submission in *mode* must go through the KernelGuard pre-check.

    Returns:
        ``True`` only when ``KERNELGUARD_ENABLED`` is switched on and *mode* is
        one of ``_GUARDED_MODES``.
    """
    if not _env_enabled("KERNELGUARD_ENABLED"):
        return False
    return mode in _GUARDED_MODES
40+
41+
42+
def _timeout_sec() -> int:
    """Return the KernelGuard CLI timeout in seconds.

    Reads ``KERNELGUARD_TIMEOUT_SEC``; values below 1 are raised to 1, and a
    value that is not an integer is logged and replaced by the default.
    """
    configured = os.getenv("KERNELGUARD_TIMEOUT_SEC", str(_DEFAULT_TIMEOUT_SEC)).strip()
    try:
        seconds = int(configured)
    except ValueError:
        logger.warning(
            "Invalid KERNELGUARD_TIMEOUT_SEC=%r, using %d", configured, _DEFAULT_TIMEOUT_SEC
        )
        return _DEFAULT_TIMEOUT_SEC
    # Never hand subprocess a zero or negative timeout.
    return max(1, seconds)
49+
50+
51+
def _profile() -> str | None:
52+
raw = os.getenv("KERNELGUARD_PROFILE", "").strip()
53+
return raw or None
54+
55+
56+
def _config_path() -> str | None:
57+
raw = os.getenv("KERNELGUARD_CONFIG", "").strip()
58+
return raw or None
59+
60+
61+
def _fail_open_enabled() -> bool:
    """Report whether ``KERNELGUARD_FAIL_OPEN`` is switched on (off when unset)."""
    return _env_enabled("KERNELGUARD_FAIL_OPEN", default=False)
63+
64+
65+
def _default_command() -> list[str]:
    """Pick the KernelGuard command line from what is available on ``PATH``.

    Returns:
        ``["kernelguard"]`` or ``["kguard"]`` (first one found, in that order),
        otherwise ``["uvx", "kernelguard"]`` when ``uvx`` is available.

    Raises:
        FileNotFoundError: None of the three executables is on ``PATH``.
    """
    installed = next(
        (name for name in ("kernelguard", "kguard") if shutil.which(name)),
        None,
    )
    if installed is not None:
        return [installed]
    if not shutil.which("uvx"):
        raise FileNotFoundError("Could not find `kernelguard`, `kguard`, or `uvx` in PATH")
    # Fall back to running the tool through uvx.
    return ["uvx", "kernelguard"]
72+
73+
74+
def _command() -> list[str]:
    """Return the argv prefix used to launch KernelGuard.

    A non-blank ``KERNELGUARD_COMMAND`` is split with shell-like quoting rules
    and used as-is; otherwise the command is discovered via ``_default_command``.
    """
    override = os.getenv("KERNELGUARD_COMMAND", "").strip()
    if not override:
        return _default_command()
    return shlex.split(override)
79+
80+
81+
def _analyze_with_cli(code: str) -> dict[str, Any]:
    """Run the KernelGuard CLI on *code* and return its JSON verdict.

    The submission source is piped to the process on stdin. The last non-blank
    line of stdout is parsed as the JSON result.

    Args:
        code: Submission source text to analyse.

    Returns:
        The decoded JSON object printed by KernelGuard.

    Raises:
        RuntimeError: The process exited non-zero, printed nothing, printed
            invalid JSON, or printed JSON that is not an object.
        subprocess.TimeoutExpired: The process ran longer than ``_timeout_sec()``.
    """
    argv = list(_command())
    profile = _profile()
    config_path = _config_path()
    # Optional flags are only passed when the matching setting is configured.
    for flag, value in (("--profile", profile), ("--config", config_path)):
        if value is not None:
            argv += [flag, value]
    argv.append("--api-mode")

    completed = subprocess.run(
        argv,
        input=code,
        text=True,
        capture_output=True,
        timeout=_timeout_sec(),
        check=False,
    )
    if completed.returncode != 0:
        # Truncate both streams so the error message stays bounded.
        stderr = limit_length(completed.stderr.strip(), 300) if completed.stderr else ""
        stdout = limit_length(completed.stdout.strip(), 300) if completed.stdout else ""
        raise RuntimeError(
            "KernelGuard command failed "
            f"(exit={completed.returncode}, stdout={stdout!r}, stderr={stderr!r})"
        )

    non_blank = [text for text in completed.stdout.splitlines() if text.strip()]
    if not non_blank:
        raise RuntimeError("KernelGuard returned no JSON result")

    # Only the final non-blank line is treated as the result payload.
    last_line = non_blank[-1]
    try:
        payload = json.loads(last_line)
    except json.JSONDecodeError as exc:
        raise RuntimeError(f"KernelGuard returned invalid JSON: {last_line!r}") from exc

    if isinstance(payload, dict):
        return payload
    raise RuntimeError("KernelGuard returned a non-object JSON payload")
119+
120+
121+
def analyze_submission(code: str) -> dict[str, Any]:
    """Analyse *code* with KernelGuard and return its JSON result object."""
    # Always use the single-shot CLI path so KERNELGUARD_TIMEOUT_SEC is enforced.
    verdict = _analyze_with_cli(code)
    return verdict
124+
125+
126+
def enforce_submission_precheck(code: str, file_name: str) -> dict[str, Any] | None:
    """Run the KernelGuard pre-check and reject submissions it flags.

    Args:
        code: Submission source text to analyse.
        file_name: Submission file name; used only in log messages.

    Returns:
        The KernelGuard result object when the submission is allowed, or
        ``None`` when KernelGuard is disabled (``KERNELGUARD_ENABLED`` off) or
        when the analysis failed and ``KERNELGUARD_FAIL_OPEN`` is on.

    Raises:
        KernelGuardRejected: The result has a truthy ``should_filter``.
        KernelBotError: The analysis itself failed and fail-open is off
            (raised with ``code=503``).
    """
    if not _env_enabled("KERNELGUARD_ENABLED"):
        return None

    try:
        result = analyze_submission(code)
    except Exception as exc:
        # Boundary handler: any analysis failure (missing binary, timeout,
        # bad output) is reported as "unavailable", never as a rejection.
        logger.warning("KernelGuard pre-check failed for %s", file_name, exc_info=exc)
        if _fail_open_enabled():
            return None
        raise KernelBotError(
            "KernelGuard pre-check is unavailable right now. Please try again later.",
            code=503,
        ) from exc

    classification = str(result.get("classification", "unknown"))
    if result.get("should_filter"):
        # `dict.get` only applies its default when the key is absent. A present
        # but null/scalar `matched_patterns` must not crash the rejection path,
        # so anything that is not a list is treated as "no patterns".
        matched = result.get("matched_patterns")
        if not isinstance(matched, list):
            matched = []
        patterns = sorted(
            {
                str(item.get("pattern", "unknown"))
                for item in matched
                if isinstance(item, dict)
            }
        )
        reason = str(result.get("filter_reason") or classification)
        details = f"Submission rejected by KernelGuard pre-check ({reason})"
        if patterns:
            details += f". Matched rules: {', '.join(patterns)}"
        raise KernelGuardRejected(details + ".", result=result)

    if classification != "valid":
        # Allowed, but worth a trace when KernelGuard did not call it clean.
        logger.info("KernelGuard classified %s as %s", file_name, classification)

    return result

src/libkernelbot/leaderboard_db.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,37 @@ def mark_submission_done(
377377
self.connection.rollback() # Ensure rollback if error occurs
378378
raise KernelBotError("Error while finalizing submission") from e
379379

380+
def mark_submission_hacked(self, submission: int, error: str | None = None) -> None:
    """Flag a submission as hacked and record that in its job status row.

    Marks the ``leaderboard.submission`` row as done with status ``'hacked'``,
    then inserts or updates the matching ``leaderboard.submission_job_status``
    row. Both statements are committed together.

    Args:
        submission: Id of the submission to flag.
        error: Optional error text; when ``None`` on update, the stored error
            is kept (``COALESCE``).

    Raises:
        KernelBotError: A ``psycopg2.Error`` occurred; the transaction is
            rolled back first.
    """
    flag_submission_sql = """
        UPDATE leaderboard.submission
        SET done = TRUE, status = 'hacked'
        WHERE id = %s
        """
    upsert_job_status_sql = """
        INSERT INTO leaderboard.submission_job_status AS s
            (submission_id, status, error, last_heartbeat)
        VALUES
            (%s, %s, %s, %s)
        ON CONFLICT (submission_id) DO UPDATE
        SET
            status = EXCLUDED.status,
            error = COALESCE(EXCLUDED.error, s.error),
            last_heartbeat = EXCLUDED.last_heartbeat
        """
    try:
        heartbeat_at = datetime.datetime.now(datetime.timezone.utc)
        self.cursor.execute(flag_submission_sql, (submission,))
        self.cursor.execute(
            upsert_job_status_sql,
            (submission, "hacked", error, heartbeat_at),
        )
        self.connection.commit()
    except psycopg2.Error as e:
        logger.error("Could not mark submission '%s' as hacked.", submission, exc_info=e)
        self.connection.rollback()
        raise KernelBotError("Error while recording hacked submission") from e
410+
380411
def update_heartbeat_if_active(self, sub_id: int, ts: datetime.datetime) -> None:
381412
try:
382413
self.cursor.execute(

0 commit comments

Comments
 (0)