Skip to content

Commit d9fbe75

Browse files
Fix all remaining linter errors and enable linter in CI
Co-authored-by: kevinbackhouse <4358136+kevinbackhouse@users.noreply.github.com>
1 parent b96207b commit d9fbe75

15 files changed

Lines changed: 1187 additions & 768 deletions

.github/workflows/ci.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@ jobs:
3434
run: pip install --upgrade hatch
3535

3636
- name: Run static analysis
37-
run: |
38-
# hatch fmt --check
39-
echo linter errors will be fixed in a separate PR
37+
run: hatch fmt --check
4038

4139
- name: Run tests
4240
run: hatch test --python ${{ matrix.python-version }} --cover --randomize --parallel --retries 2 --retry-delay 1

src/seclab_taskflows/mcp_servers/alert_results_models.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,16 @@
33

44
from __future__ import annotations
55

6-
from typing import Optional
7-
86
from sqlalchemy import Column, ForeignKey, Integer, Text
97
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
108

119

1210
class Base(DeclarativeBase):
1311
pass
1412

13+
1514
class AlertResults(Base):
16-
__tablename__ = 'alert_results'
15+
__tablename__ = "alert_results"
1716

1817
canonical_id: Mapped[int] = mapped_column(primary_key=True)
1918
alert_id: Mapped[str]
@@ -22,29 +21,33 @@ class AlertResults(Base):
2221
language: Mapped[str]
2322
location: Mapped[str]
2423
result: Mapped[str] = mapped_column(Text)
25-
created: Mapped[Optional[str]]
24+
created: Mapped[str | None]
2625
valid: Mapped[bool] = mapped_column(nullable=False, default=True)
2726
completed: Mapped[bool] = mapped_column(nullable=False, default=False)
2827

29-
relationship('AlertFlowGraph', cascade='all, delete')
28+
relationship("AlertFlowGraph", cascade="all, delete")
3029

3130
def __repr__(self):
32-
return (f"<AlertResults(alert_id={self.alert_id}, repo={self.repo}, "
33-
f"rule={self.rule}, language={self.language}, location={self.location}, "
34-
f"result={self.result}, created_at={self.created}, valid={self.valid}, completed={self.completed})>")
31+
return (
32+
f"<AlertResults(alert_id={self.alert_id}, repo={self.repo}, "
33+
f"rule={self.rule}, language={self.language}, location={self.location}, "
34+
f"result={self.result}, created_at={self.created}, valid={self.valid}, completed={self.completed})>"
35+
)
36+
3537

3638
class AlertFlowGraph(Base):
37-
__tablename__ = 'alert_flow_graph'
39+
__tablename__ = "alert_flow_graph"
3840

3941
id: Mapped[int] = mapped_column(primary_key=True)
40-
alert_canonical_id = Column(Integer, ForeignKey('alert_results.canonical_id', ondelete='CASCADE'))
42+
alert_canonical_id = Column(Integer, ForeignKey("alert_results.canonical_id", ondelete="CASCADE"))
4143
flow_data: Mapped[str] = mapped_column(Text)
4244
repo: Mapped[str]
43-
prev: Mapped[Optional[str]]
44-
next: Mapped[Optional[str]]
45+
prev: Mapped[str | None]
46+
next: Mapped[str | None]
4547
started: Mapped[bool] = mapped_column(nullable=False, default=False)
4648

4749
def __repr__(self):
48-
return (f"<AlertFlowGraph(alert_canonical_id={self.alert_canonical_id}, "
49-
f"flow_data={self.flow_data}, repo={self.repo}, prev={self.prev}, next={self.next}, started={self.started})>")
50-
50+
return (
51+
f"<AlertFlowGraph(alert_canonical_id={self.alert_canonical_id}, "
52+
f"flow_data={self.flow_data}, repo={self.repo}, prev={self.prev}, next={self.next}, started={self.started})>"
53+
)

src/seclab_taskflows/mcp_servers/codeql_python/codeql_sqlite_models.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33

44
from __future__ import annotations
55

6-
from typing import Optional
7-
86
from sqlalchemy import Text
97
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
108

@@ -14,16 +12,18 @@ class Base(DeclarativeBase):
1412

1513

1614
class Source(Base):
17-
__tablename__ = 'source'
15+
__tablename__ = "source"
1816

1917
id: Mapped[int] = mapped_column(primary_key=True)
2018
repo: Mapped[str]
2119
source_location: Mapped[str]
2220
line: Mapped[int]
2321
source_type: Mapped[str]
24-
notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
22+
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
2523

2624
def __repr__(self):
27-
return (f"<Source(id={self.id}, repo={self.repo}, "
28-
f"location={self.source_location}, line={self.line}, source_type={self.source_type}, "
29-
f"notes={self.notes})>")
25+
return (
26+
f"<Source(id={self.id}, repo={self.repo}, "
27+
f"location={self.source_location}, line={self.line}, source_type={self.source_type}, "
28+
f"notes={self.notes})>"
29+
)

src/seclab_taskflows/mcp_servers/codeql_python/mcp_server.py

Lines changed: 75 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -11,35 +11,34 @@
1111
import subprocess
1212
from pathlib import Path
1313

14-
#from mcp.server.fastmcp import FastMCP, Context
1514
from fastmcp import FastMCP # use FastMCP 2.0
1615
from pydantic import Field
1716
from seclab_taskflow_agent.mcp_servers.codeql.client import _debug_log, run_query
1817
from seclab_taskflow_agent.path_utils import log_file_name, mcp_data_dir
1918
from sqlalchemy import create_engine
2019
from sqlalchemy.orm import Session
2120

22-
from seclab_taskflows.mcp_servers.utils import process_repo
2321
from seclab_taskflows.mcp_servers.codeql_python.codeql_sqlite_models import Base, Source
22+
from seclab_taskflows.mcp_servers.utils import process_repo
2423

2524
logging.basicConfig(
2625
level=logging.DEBUG,
27-
format='%(asctime)s - %(levelname)s - %(message)s',
28-
filename=log_file_name('mcp_codeql_python.log'),
29-
filemode='a'
26+
format="%(asctime)s - %(levelname)s - %(message)s",
27+
filename=log_file_name("mcp_codeql_python.log"),
28+
filemode="a",
3029
)
3130

32-
MEMORY = mcp_data_dir('seclab-taskflows', 'codeql', 'DATA_DIR')
33-
CODEQL_DBS_BASE_PATH = mcp_data_dir('seclab-taskflows', 'codeql', 'CODEQL_DBS_BASE_PATH')
31+
logger = logging.getLogger(__name__)
32+
33+
MEMORY = mcp_data_dir("seclab-taskflows", "codeql", "DATA_DIR")
34+
CODEQL_DBS_BASE_PATH = mcp_data_dir("seclab-taskflows", "codeql", "CODEQL_DBS_BASE_PATH")
3435

3536
mcp = FastMCP("CodeQL-Python")
3637

3738
# tool name -> templated query lookup for supported languages
3839
TEMPLATED_QUERY_PATHS = {
3940
# to add a language, port the templated query pack and add its definition here
40-
'python': {
41-
'remote_sources': 'queries/mcp-python/remote_sources.ql'
42-
}
41+
"python": {"remote_sources": "queries/mcp-python/remote_sources.ql"}
4342
}
4443

4544

@@ -50,9 +49,10 @@ def source_to_dict(result):
5049
"source_location": result.source_location,
5150
"line": result.line,
5251
"source_type": result.source_type,
53-
"notes": result.notes
52+
"notes": result.notes,
5453
}
5554

55+
5656
def _resolve_query_path(language: str, query: str) -> Path:
5757
if language not in TEMPLATED_QUERY_PATHS:
5858
msg = f"Error: Language `{language}` not supported!"
@@ -67,7 +67,7 @@ def _resolve_query_path(language: str, query: str) -> Path:
6767
def _resolve_db_path(relative_db_path: str | Path):
6868
# path joins will return "/B" if "/A" / "////B" etc. as well
6969
# not windows compatible and probably needs additional hardening
70-
relative_db_path = str(relative_db_path).strip().lstrip('/')
70+
relative_db_path = str(relative_db_path).strip().lstrip("/")
7171
relative_db_path = Path(relative_db_path)
7272
absolute_path = (CODEQL_DBS_BASE_PATH / relative_db_path).resolve()
7373
if not absolute_path.is_relative_to(CODEQL_DBS_BASE_PATH.resolve()):
@@ -79,37 +79,38 @@ def _resolve_db_path(relative_db_path: str | Path):
7979
raise RuntimeError(msg)
8080
return str(absolute_path)
8181

82+
8283
# This sqlite database is specifically made for CodeQL for Python MCP.
8384
class CodeqlSqliteBackend:
8485
def __init__(self, memcache_state_dir: str):
8586
self.memcache_state_dir = memcache_state_dir
8687
if not Path(self.memcache_state_dir).exists():
87-
db_dir = 'sqlite://'
88+
db_dir = "sqlite://"
8889
else:
89-
db_dir = f'sqlite:///{self.memcache_state_dir}/codeql_sqlite.db'
90+
db_dir = f"sqlite:///{self.memcache_state_dir}/codeql_sqlite.db"
9091
self.engine = create_engine(db_dir, echo=False)
9192
Base.metadata.create_all(self.engine, tables=[Source.__table__])
9293

93-
94-
def store_new_source(self, repo, source_location, line, source_type, notes, *, update = False):
94+
def store_new_source(self, repo, source_location, line, source_type, notes, *, update=False):
9595
with Session(self.engine) as session:
96-
existing = session.query(Source).filter_by(repo = repo, source_location = source_location, line = line).first()
96+
existing = session.query(Source).filter_by(repo=repo, source_location=source_location, line=line).first()
9797
if existing:
9898
existing.notes = (existing.notes or "") + notes
9999
session.commit()
100100
return f"Updated notes for source at {source_location}, line {line} in {repo}."
101101
if update:
102102
return f"No source exists at repo {repo}, location {source_location}, line {line} to update."
103-
new_source = Source(repo = repo, source_location = source_location, line = line, source_type = source_type, notes = notes)
103+
new_source = Source(
104+
repo=repo, source_location=source_location, line=line, source_type=source_type, notes=notes
105+
)
104106
session.add(new_source)
105107
session.commit()
106108
return f"Added new source for {source_location} in {repo}."
107109

108110
def get_sources(self, repo):
109111
with Session(self.engine) as session:
110-
results = session.query(Source).filter_by(repo = repo).all()
111-
sources = [source_to_dict(source) for source in results]
112-
return sources
112+
results = session.query(Source).filter_by(repo=repo).all()
113+
return [source_to_dict(source) for source in results]
113114

114115

115116
# our query result format is: "human readable template {val0} {val1},'key0,key1',val0,val1"
@@ -121,8 +122,8 @@ def _csv_parse(raw):
121122
if i == 0:
122123
continue
123124
# col1 has what we care about, but offer flexibility
124-
keys = row[1].split(',')
125-
this_obj = {'description': row[0].format(*row[2:])}
125+
keys = row[1].split(",")
126+
this_obj = {"description": row[0].format(*row[2:])}
126127
for j, k in enumerate(keys):
127128
this_obj[k.strip()] = row[j + 2]
128129
results.append(this_obj)
@@ -143,27 +144,32 @@ def _run_query(query_name: str, database_path: str, language: str, template_valu
143144
except RuntimeError:
144145
return f"The query {query_name} is not supported for language: {language}"
145146
try:
146-
csv = run_query(Path(__file__).parent.resolve() /
147-
query_path,
148-
database_path,
149-
fmt='csv',
150-
template_values=template_values,
151-
log_stderr=True)
147+
csv = run_query(
148+
Path(__file__).parent.resolve() / query_path,
149+
database_path,
150+
fmt="csv",
151+
template_values=template_values,
152+
log_stderr=True,
153+
)
152154
return _csv_parse(csv)
153-
except Exception as e:
155+
except (subprocess.CalledProcessError, FileNotFoundError) as e:
154156
return f"The query {query_name} encountered an error: {e}"
155157

158+
156159
backend = CodeqlSqliteBackend(MEMORY)
157160

161+
158162
@mcp.tool()
159-
def remote_sources(owner: str = Field(description="The owner of the GitHub repository"),
160-
repo: str = Field(description="The name of the GitHub repository"),
161-
database_path: str = Field(description="The CodeQL database path."),
162-
language: str = Field(description="The language used for the CodeQL database.")):
163+
def remote_sources(
164+
owner: str = Field(description="The owner of the GitHub repository"),
165+
repo: str = Field(description="The name of the GitHub repository"),
166+
database_path: str = Field(description="The CodeQL database path."),
167+
language: str = Field(description="The language used for the CodeQL database."),
168+
):
163169
"""List all remote sources and their locations in a CodeQL database, then store the results in a database."""
164170

165171
repo = process_repo(owner, repo)
166-
results = _run_query('remote_sources', database_path, language, {})
172+
results = _run_query("remote_sources", database_path, language, {})
167173

168174
# Check if results is an error (list of strings) or valid data (list of dicts)
169175
if isinstance(results, str):
@@ -174,53 +180,67 @@ def remote_sources(owner: str = Field(description="The owner of the GitHub repos
174180
for result in results:
175181
backend.store_new_source(
176182
repo=repo,
177-
source_location=result.get('location', ''),
178-
source_type=result.get('source', ''),
179-
line=int(result.get('line', '0')),
180-
notes=None, #result.get('description', ''),
181-
update=False
183+
source_location=result.get("location", ""),
184+
source_type=result.get("source", ""),
185+
line=int(result.get("line", "0")),
186+
notes=None, # result.get('description', ''),
187+
update=False,
182188
)
183189
stored_count += 1
184190

185191
return f"Stored {stored_count} remote sources in {repo}."
186192

193+
187194
@mcp.tool()
188-
def fetch_sources(owner: str = Field(description="The owner of the GitHub repository"),
189-
repo: str = Field(description="The name of the GitHub repository")):
195+
def fetch_sources(
196+
owner: str = Field(description="The owner of the GitHub repository"),
197+
repo: str = Field(description="The name of the GitHub repository"),
198+
):
190199
"""
191200
Fetch all sources from the repo
192201
"""
193202
repo = process_repo(owner, repo)
194203
return json.dumps(backend.get_sources(repo))
195204

205+
196206
@mcp.tool()
197-
def add_source_notes(owner: str = Field(description="The owner of the GitHub repository"),
198-
repo: str = Field(description="The name of the GitHub repository"),
199-
source_location: str = Field(description="The path to the file"),
200-
line: int = Field(description="The line number of the source"),
201-
notes: str = Field(description="The notes to append to this source")):
207+
def add_source_notes(
208+
owner: str = Field(description="The owner of the GitHub repository"),
209+
repo: str = Field(description="The name of the GitHub repository"),
210+
source_location: str = Field(description="The path to the file"),
211+
line: int = Field(description="The line number of the source"),
212+
notes: str = Field(description="The notes to append to this source"),
213+
):
202214
"""
203215
Add new notes to an existing source. The notes will be appended to any existing notes.
204216
"""
205217
repo = process_repo(owner, repo)
206-
return backend.store_new_source(repo = repo, source_location = source_location, line = line, source_type = "", notes = notes, update=True)
218+
return backend.store_new_source(
219+
repo=repo, source_location=source_location, line=line, source_type="", notes=notes, update=True
220+
)
221+
207222

208223
@mcp.tool()
209-
def clear_codeql_repo(owner: str = Field(description="The owner of the GitHub repository"),
210-
repo: str = Field(description="The name of the GitHub repository")):
224+
def clear_codeql_repo(
225+
owner: str = Field(description="The owner of the GitHub repository"),
226+
repo: str = Field(description="The name of the GitHub repository"),
227+
):
211228
"""
212229
Clear all data for a given repo from the database
213230
"""
214231
repo = process_repo(owner, repo)
215232
with Session(backend.engine) as session:
216-
deleted_sources = session.query(Source).filter_by(repo = repo).delete()
233+
deleted_sources = session.query(Source).filter_by(repo=repo).delete()
217234
session.commit()
218235
return f"Cleared {deleted_sources} sources from repo {repo}."
219236

237+
220238
if __name__ == "__main__":
221239
# Check if codeql/python-all pack is installed, if not install it
222-
if not os.path.isdir('/.codeql/packages/codeql/python-all'):
223-
pack_path = importlib.resources.files('seclab_taskflows.mcp_servers.codeql_python.queries').joinpath('mcp-python')
224-
print(f"Installing CodeQL pack from {pack_path}")
225-
subprocess.run(["codeql", "pack", "install", pack_path], check=False)
240+
if not os.path.isdir("/.codeql/packages/codeql/python-all"):
241+
pack_path = importlib.resources.files("seclab_taskflows.mcp_servers.codeql_python.queries").joinpath(
242+
"mcp-python"
243+
)
244+
logger.info("Installing CodeQL pack from %s", pack_path)
245+
subprocess.run(["/usr/local/bin/codeql", "pack", "install", pack_path], check=False)
226246
mcp.run(show_banner=False, transport="http", host="127.0.0.1", port=9998)

0 commit comments

Comments
 (0)