Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ ignore = [
"EM102", # Exception f-strings
"G004", # Logging f-strings
"T201", # print() used for user output
"TRY003", # Raise with inline message strings

# Backwards-compatibility suppressions for existing code
"A001", # Variable shadows built-in
Expand Down Expand Up @@ -253,3 +252,4 @@ ignore = [

[tool.ruff.lint.per-file-ignores]
"tests/*" = ["S101", "PLR2004"]
"src/seclab_taskflow_agent/mcp_servers/codeql/jsonrpyc/*" = ["TRY003"]
3 changes: 2 additions & 1 deletion src/seclab_taskflow_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def __init__(
if token:
resolved_token = os.getenv(token, "")
if not resolved_token:
raise RuntimeError(f"Token env var {token!r} is not set")
msg = f"Token env var {token!r} is not set"
raise RuntimeError(msg)
else:
resolved_token = get_AI_token()

Expand Down
40 changes: 16 additions & 24 deletions src/seclab_taskflow_agent/available_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,15 @@ def _load(self, tooltype: AvailableToolType, toolname: str) -> DocumentModel:
# Resolve package and filename from dotted path
components = toolname.rsplit(".", 1)
if len(components) != 2:
raise BadToolNameError(
f'Not a valid toolname: "{toolname}". '
f'Expected format: "packagename.filename"'
)
msg = f'Not a valid toolname: "{toolname}". Expected format: "packagename.filename"'
raise BadToolNameError(msg)
package, filename = components

try:
pkg_dir = importlib.resources.files(package)
if not pkg_dir.is_dir():
raise BadToolNameError(
f"Cannot load {toolname} because {pkg_dir} is not a valid directory."
)
msg = f"Cannot load {toolname} because {pkg_dir} is not a valid directory."
raise BadToolNameError(msg)
filepath = pkg_dir.joinpath(filename + ".yaml")
with filepath.open() as fh:
raw = yaml.safe_load(fh)
Expand All @@ -128,17 +125,14 @@ def _load(self, tooltype: AvailableToolType, toolname: str) -> DocumentModel:
header = raw.get("seclab-taskflow-agent", {})
filetype = header.get("filetype", "")
if filetype != tooltype.value:
raise FileTypeException(
f"Error in {filepath}: expected filetype {tooltype.value!r}, "
f"got {filetype!r}."
)
msg = f"Error in {filepath}: expected filetype {tooltype.value!r}, got {filetype!r}."
raise FileTypeException(msg)

# Parse into the appropriate Pydantic model
model_cls = DOCUMENT_MODELS.get(filetype)
if model_cls is None:
raise BadToolNameError(
f"Unknown filetype {filetype!r} in {toolname}"
)
msg = f"Unknown filetype {filetype!r} in {toolname}"
raise BadToolNameError(msg)

try:
doc = model_cls(**raw)
Expand All @@ -147,21 +141,19 @@ def _load(self, tooltype: AvailableToolType, toolname: str) -> DocumentModel:
for err in exc.errors():
if "Unsupported version" in str(err.get("msg", "")):
raise VersionException(str(err["msg"])) from exc
raise BadToolNameError(
f"Validation error loading {toolname}: {exc}"
) from exc

# Cache and return
msg = f"Validation error loading {toolname}: {exc}"
raise BadToolNameError(msg) from exc
if tooltype not in self._cache:
self._cache[tooltype] = {}
self._cache[tooltype][toolname] = doc
return doc

except ModuleNotFoundError as exc:
raise BadToolNameError(f"Cannot load {toolname}: {exc}") from exc
msg = f"Cannot load {toolname}: {exc}"
raise BadToolNameError(msg) from exc
except FileNotFoundError:
raise BadToolNameError(
f"Cannot load {toolname} because {filepath} is not a valid file."
)
msg = f"Cannot load {toolname} because {filepath} is not a valid file."
raise BadToolNameError(msg)
except ValueError as exc:
raise BadToolNameError(f"Cannot load {toolname}: {exc}") from exc
msg = f"Cannot load {toolname}: {exc}"
raise BadToolNameError(msg) from exc
3 changes: 2 additions & 1 deletion src/seclab_taskflow_agent/capi.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ def get_AI_token() -> str:
token = os.getenv("COPILOT_TOKEN")
if token:
return token
raise RuntimeError("AI_API_TOKEN environment variable is not set.")
msg = "AI_API_TOKEN environment variable is not set."
raise RuntimeError(msg)


# ---------------------------------------------------------------------------
Expand Down
3 changes: 2 additions & 1 deletion src/seclab_taskflow_agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@
def _parse_global(value: str) -> tuple[str, str]:
"""Parse a ``KEY=VALUE`` string into a (key, value) pair."""
if "=" not in value:
raise typer.BadParameter(f"Invalid global variable format: {value!r}. Expected KEY=VALUE.")
msg = f"Invalid global variable format: {value!r}. Expected KEY=VALUE."
raise typer.BadParameter(msg)
key, _, val = value.partition("=")
return key.strip(), val.strip()

Expand Down
3 changes: 2 additions & 1 deletion src/seclab_taskflow_agent/mcp_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def _print_err(line: str) -> None:
client_session_timeout_seconds=client_session_timeout,
)
case _:
raise ValueError(f"Unsupported MCP transport: {params['kind']}")
msg = f"Unsupported MCP transport: {params['kind']}"
raise ValueError(msg)

entries.append(MCPServerEntry(MCPNamespaceWrap(confirms, mcp_server), server_proc, name=tb))

Expand Down
24 changes: 16 additions & 8 deletions src/seclab_taskflow_agent/mcp_servers/codeql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,12 @@ def _server_request_run(
template_values: dict | None = None,
):
if not self.active_database:
raise RuntimeError("No Active Database")
msg = "No Active Database"
raise RuntimeError(msg)

if not self.active_connection:
raise RuntimeError("No Active Connection")
msg = "No Active Connection"
raise RuntimeError(msg)

if isinstance(quick_eval_pos, dict):
# A quick eval position contains:
Expand Down Expand Up @@ -302,7 +304,8 @@ def _format(self, query):
def _resolve_query_server(self):
help_msg = shell_command_to_string(self.codeql_cli + ["excute", "--help"])
if not re.search("query-server2", help_msg):
raise RuntimeError("Legacy server not supported!")
msg = "Legacy server not supported!"
raise RuntimeError(msg)
return "query-server2"

def _resolve_library_paths(self, query_path):
Expand Down Expand Up @@ -463,11 +466,13 @@ def _file_uri_to_path(uri):
# internally the codeql client will resolve both relative and full paths
# regardless of root directory differences
if not uri.startswith("file:///"):
raise ValueError("URI path should be formatted as absolute")
msg = "URI path should be formatted as absolute"
raise ValueError(msg)
# note: don't try to parse paths like "file://a/b" because that returns "/b", should be "file:///a/b"
parsed = urlparse(uri)
if parsed.scheme != "file":
raise ValueError(f"Not a file:// uri: {uri}")
msg = f"Not a file:// uri: {uri}"
raise ValueError(msg)
path = unquote(parsed.path)
region = None
if ":" in path:
Expand Down Expand Up @@ -605,7 +610,8 @@ def run_query(
if target:
target_pos = get_query_position(query_path, target)
if not target_pos:
raise ValueError(f"Could not resolve quick eval target for {target}")
msg = f"Could not resolve quick eval target for {target}"
raise ValueError(msg)
try:
with (
QueryServer(database, keep_alive=keep_alive, log_stderr=log_stderr) as server,
Expand Down Expand Up @@ -633,7 +639,9 @@ def run_query(
case "sarif":
result = server._bqrs_to_sarif(bqrs_path, server._query_info(query_path))
case _:
raise ValueError("Unsupported output format {fmt}")
msg = f"Unsupported output format {fmt}"
raise ValueError(msg)
except Exception as e:
raise RuntimeError(f"Error in run_query: {e}") from e
msg = f"Error in run_query: {e}"
raise RuntimeError(msg) from e
return result
9 changes: 6 additions & 3 deletions src/seclab_taskflow_agent/mcp_servers/codeql/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@
def _resolve_query_path(language: str, query: str) -> Path:
global TEMPLATED_QUERY_PATHS
if language not in TEMPLATED_QUERY_PATHS:
raise RuntimeError(f"Error: Language `{language}` not supported!")
msg = f"Error: Language `{language}` not supported!"
raise RuntimeError(msg)
query_path = TEMPLATED_QUERY_PATHS[language].get(query)
if not query_path:
raise RuntimeError(f"Error: query `{query}` not supported for `{language}`!")
msg = f"Error: query `{query}` not supported for `{language}`!"
raise RuntimeError(msg)
return Path(query_path)


Expand All @@ -69,7 +71,8 @@ def _resolve_db_path(relative_db_path: str | Path):
absolute_path = CODEQL_DBS_BASE_PATH / relative_db_path
if not absolute_path.is_dir():
_debug_log(f"Database path not found: {absolute_path}")
raise RuntimeError(f"Error: Database not found at {absolute_path}!")
msg = f"Error: Database not found at {absolute_path}!"
raise RuntimeError(msg)
return absolute_path


Expand Down
15 changes: 10 additions & 5 deletions src/seclab_taskflow_agent/mcp_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ async def async_wait_for_connection(
host = parsed.hostname
port = parsed.port
if host is None or port is None:
raise ValueError(f"URL must include a host and port: {self.url}")
msg = f"URL must include a host and port: {self.url}"
raise ValueError(msg)
deadline = asyncio.get_event_loop().time() + timeout
while True:
try:
Expand All @@ -119,7 +120,8 @@ async def async_wait_for_connection(
return
except (OSError, ConnectionRefusedError):
if asyncio.get_event_loop().time() > deadline:
raise TimeoutError(f"Could not connect to {host}:{port} after {timeout} seconds")
msg = f"Could not connect to {host}:{port} after {timeout} seconds"
raise TimeoutError(msg)
await asyncio.sleep(poll_interval)

def wait_for_connection(
Expand All @@ -139,15 +141,17 @@ def wait_for_connection(
host = parsed.hostname
port = parsed.port
if host is None or port is None:
raise ValueError(f"URL must include a host and port: {self.url}")
msg = f"URL must include a host and port: {self.url}"
raise ValueError(msg)
deadline = time.time() + timeout
while True:
try:
with socket.create_connection((host, port), timeout=2):
return
except OSError:
if time.time() > deadline:
raise TimeoutError(f"Could not connect to {host}:{port} after {timeout} seconds")
msg = f"Could not connect to {host}:{port} after {timeout} seconds"
raise TimeoutError(msg)
time.sleep(poll_interval)

def run(self) -> None:
Expand Down Expand Up @@ -216,7 +220,8 @@ def join_and_raise(self, timeout: float | None = None) -> None:
"""
self.join(timeout)
if self.is_alive():
raise RuntimeError("Process thread did not exit within timeout.")
msg = "Process thread did not exit within timeout."
raise RuntimeError(msg)
if self.exception is not None:
raise self.exception

Expand Down
6 changes: 4 additions & 2 deletions src/seclab_taskflow_agent/mcp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ def mcp_client_params(
logging.debug(f"Initializing streamable toolbox: {tb}\nargs:\n{args}\nenv:\n{env}\n")
exe = shutil.which(sp.command)
if exe is None:
raise FileNotFoundError(f"Could not resolve path to {sp.command}")
msg = f"Could not resolve path to {sp.command}"
raise FileNotFoundError(msg)
start_cmd = [exe]
if args:
for i, v in enumerate(args):
Expand All @@ -220,7 +221,8 @@ def mcp_client_params(
server_params["env"] = env

case _:
raise ValueError(f"Unsupported MCP transport {kind}")
msg = f"Unsupported MCP transport {kind}"
raise ValueError(msg)

client_params[tb] = (
server_params,
Expand Down
8 changes: 4 additions & 4 deletions src/seclab_taskflow_agent/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,8 @@ def _normalise_version(cls, v: Any) -> str:
@classmethod
def _validate_version(cls, v: str) -> str:
if v != SUPPORTED_VERSION:
raise ValueError(
f"Unsupported version: {v}. Only version {SUPPORTED_VERSION} is supported."
)
msg = f"Unsupported version: {v}. Only version {SUPPORTED_VERSION} is supported."
raise ValueError(msg)
return v


Expand Down Expand Up @@ -106,7 +105,8 @@ class TaskDefinition(BaseModel):
@model_validator(mode="after")
def _run_xor_prompt(self) -> TaskDefinition:
if self.run and self.user_prompt:
raise ValueError("shell task ('run') and prompt task ('user_prompt') are mutually exclusive")
msg = "shell task ('run') and prompt task ('user_prompt') are mutually exclusive"
raise ValueError(msg)
return self


Expand Down
Loading
Loading