Skip to content

Commit cdcd22f

Browse files
Revert formatting changes and only enable linter check in CI
Co-authored-by: kevinbackhouse <4358136+kevinbackhouse@users.noreply.github.com>
1 parent 2163a80 commit cdcd22f

15 files changed

Lines changed: 659 additions & 1062 deletions

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ jobs:
3434
run: pip install --upgrade hatch
3535

3636
- name: Run static analysis
37-
run: hatch fmt --check
37+
run: hatch fmt --linter --check
3838

3939
- name: Run tests
4040
run: hatch test --python ${{ matrix.python-version }} --cover --randomize --parallel --retries 2 --retry-delay 1

src/seclab_taskflows/mcp_servers/alert_results_models.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,11 @@
55
from sqlalchemy.orm import DeclarativeBase, mapped_column, Mapped, relationship
66
from typing import Optional
77

8-
98
class Base(DeclarativeBase):
109
pass
1110

12-
1311
class AlertResults(Base):
14-
__tablename__ = "alert_results"
12+
__tablename__ = 'alert_results'
1513

1614
canonical_id: Mapped[int] = mapped_column(primary_key=True)
1715
alert_id: Mapped[str]
@@ -24,29 +22,25 @@ class AlertResults(Base):
2422
valid: Mapped[bool] = mapped_column(nullable=False, default=True)
2523
completed: Mapped[bool] = mapped_column(nullable=False, default=False)
2624

27-
relationship("AlertFlowGraph", cascade="all, delete")
25+
relationship('AlertFlowGraph', cascade='all, delete')
2826

2927
def __repr__(self):
30-
return (
31-
f"<AlertResults(alert_id={self.alert_id}, repo={self.repo}, "
32-
f"rule={self.rule}, language={self.language}, location={self.location}, "
33-
f"result={self.result}, created_at={self.created}, valid={self.valid}, completed={self.completed})>"
34-
)
35-
28+
return (f"<AlertResults(alert_id={self.alert_id}, repo={self.repo}, "
29+
f"rule={self.rule}, language={self.language}, location={self.location}, "
30+
f"result={self.result}, created_at={self.created}, valid={self.valid}, completed={self.completed})>")
3631

3732
class AlertFlowGraph(Base):
38-
__tablename__ = "alert_flow_graph"
33+
__tablename__ = 'alert_flow_graph'
3934

4035
id: Mapped[int] = mapped_column(primary_key=True)
41-
alert_canonical_id = Column(Integer, ForeignKey("alert_results.canonical_id", ondelete="CASCADE"))
36+
alert_canonical_id = Column(Integer, ForeignKey('alert_results.canonical_id', ondelete='CASCADE'))
4237
flow_data: Mapped[str] = mapped_column(Text)
4338
repo: Mapped[str]
4439
prev: Mapped[Optional[str]]
4540
next: Mapped[Optional[str]]
4641
started: Mapped[bool] = mapped_column(nullable=False, default=False)
4742

4843
def __repr__(self):
49-
return (
50-
f"<AlertFlowGraph(alert_canonical_id={self.alert_canonical_id}, "
51-
f"flow_data={self.flow_data}, repo={self.repo}, prev={self.prev}, next={self.next}, started={self.started})>"
52-
)
44+
return (f"<AlertFlowGraph(alert_canonical_id={self.alert_canonical_id}, "
45+
f"flow_data={self.flow_data}, repo={self.repo}, prev={self.prev}, next={self.next}, started={self.started})>")
46+

src/seclab_taskflows/mcp_servers/codeql_python/codeql_sqlite_models.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55
from sqlalchemy.orm import DeclarativeBase, mapped_column, Mapped
66
from typing import Optional
77

8-
98
class Base(DeclarativeBase):
109
pass
1110

1211

1312
class Source(Base):
14-
__tablename__ = "source"
13+
__tablename__ = 'source'
1514

1615
id: Mapped[int] = mapped_column(primary_key=True)
1716
repo: Mapped[str]
@@ -21,8 +20,6 @@ class Source(Base):
2120
notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
2221

2322
def __repr__(self):
24-
return (
25-
f"<Source(id={self.id}, repo={self.repo}, "
26-
f"location={self.source_location}, line={self.line}, source_type={self.source_type}, "
27-
f"notes={self.notes})>"
28-
)
23+
return (f"<Source(id={self.id}, repo={self.repo}, "
24+
f"location={self.source_location}, line={self.line}, source_type={self.source_type}, "
25+
f"notes={self.notes})>")

src/seclab_taskflows/mcp_servers/codeql_python/mcp_server.py

Lines changed: 50 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66
from seclab_taskflow_agent.mcp_servers.codeql.client import run_query, _debug_log
77

88
from pydantic import Field
9-
10-
# from mcp.server.fastmcp import FastMCP, Context
11-
from fastmcp import FastMCP # use FastMCP 2.0
9+
#from mcp.server.fastmcp import FastMCP, Context
10+
from fastmcp import FastMCP # use FastMCP 2.0
1211
from pathlib import Path
1312
import os
1413
import csv
@@ -24,20 +23,22 @@
2423

2524
logging.basicConfig(
2625
level=logging.DEBUG,
27-
format="%(asctime)s - %(levelname)s - %(message)s",
28-
filename=log_file_name("mcp_codeql_python.log"),
29-
filemode="a",
26+
format='%(asctime)s - %(levelname)s - %(message)s',
27+
filename=log_file_name('mcp_codeql_python.log'),
28+
filemode='a'
3029
)
3130

32-
MEMORY = mcp_data_dir("seclab-taskflows", "codeql", "DATA_DIR")
33-
CODEQL_DBS_BASE_PATH = mcp_data_dir("seclab-taskflows", "codeql", "CODEQL_DBS_BASE_PATH")
31+
MEMORY = mcp_data_dir('seclab-taskflows', 'codeql', 'DATA_DIR')
32+
CODEQL_DBS_BASE_PATH = mcp_data_dir('seclab-taskflows', 'codeql', 'CODEQL_DBS_BASE_PATH')
3433

3534
mcp = FastMCP("CodeQL-Python")
3635

3736
# tool name -> templated query lookup for supported languages
3837
TEMPLATED_QUERY_PATHS = {
3938
# to add a language, port the templated query pack and add its definition here
40-
"python": {"remote_sources": "queries/mcp-python/remote_sources.ql"}
39+
'python': {
40+
'remote_sources': 'queries/mcp-python/remote_sources.ql'
41+
}
4142
}
4243

4344

@@ -48,10 +49,9 @@ def source_to_dict(result):
4849
"source_location": result.source_location,
4950
"line": result.line,
5051
"source_type": result.source_type,
51-
"notes": result.notes,
52+
"notes": result.notes
5253
}
5354

54-
5555
def _resolve_query_path(language: str, query: str) -> Path:
5656
global TEMPLATED_QUERY_PATHS
5757
if language not in TEMPLATED_QUERY_PATHS:
@@ -66,7 +66,7 @@ def _resolve_db_path(relative_db_path: str | Path):
6666
global CODEQL_DBS_BASE_PATH
6767
# path joins will return "/B" if "/A" / "////B" etc. as well
6868
# not windows compatible and probably needs additional hardening
69-
relative_db_path = str(relative_db_path).strip().lstrip("/")
69+
relative_db_path = str(relative_db_path).strip().lstrip('/')
7070
relative_db_path = Path(relative_db_path)
7171
absolute_path = (CODEQL_DBS_BASE_PATH / relative_db_path).resolve()
7272
if not absolute_path.is_relative_to(CODEQL_DBS_BASE_PATH.resolve()):
@@ -76,38 +76,36 @@ def _resolve_db_path(relative_db_path: str | Path):
7676
raise RuntimeError(f"Error: Database not found at {absolute_path}!")
7777
return str(absolute_path)
7878

79-
8079
# This sqlite database is specifically made for CodeQL for Python MCP.
8180
class CodeqlSqliteBackend:
8281
def __init__(self, memcache_state_dir: str):
8382
self.memcache_state_dir = memcache_state_dir
8483
if not Path(self.memcache_state_dir).exists():
85-
db_dir = "sqlite://"
84+
db_dir = 'sqlite://'
8685
else:
87-
db_dir = f"sqlite:///{self.memcache_state_dir}/codeql_sqlite.db"
86+
db_dir = f'sqlite:///{self.memcache_state_dir}/codeql_sqlite.db'
8887
self.engine = create_engine(db_dir, echo=False)
8988
Base.metadata.create_all(self.engine, tables=[Source.__table__])
9089

91-
def store_new_source(self, repo, source_location, line, source_type, notes, update=False):
90+
91+
def store_new_source(self, repo, source_location, line, source_type, notes, update = False):
9292
with Session(self.engine) as session:
93-
existing = session.query(Source).filter_by(repo=repo, source_location=source_location, line=line).first()
93+
existing = session.query(Source).filter_by(repo = repo, source_location = source_location, line = line).first()
9494
if existing:
9595
existing.notes = (existing.notes or "") + notes
9696
session.commit()
9797
return f"Updated notes for source at {source_location}, line {line} in {repo}."
9898
else:
9999
if update:
100100
return f"No source exists at repo {repo}, location {source_location}, line {line} to update."
101-
new_source = Source(
102-
repo=repo, source_location=source_location, line=line, source_type=source_type, notes=notes
103-
)
101+
new_source = Source(repo = repo, source_location = source_location, line = line, source_type = source_type, notes = notes)
104102
session.add(new_source)
105103
session.commit()
106104
return f"Added new source for {source_location} in {repo}."
107105

108106
def get_sources(self, repo):
109107
with Session(self.engine) as session:
110-
results = session.query(Source).filter_by(repo=repo).all()
108+
results = session.query(Source).filter_by(repo = repo).all()
111109
sources = [source_to_dict(source) for source in results]
112110
return sources
113111

@@ -121,8 +119,8 @@ def _csv_parse(raw):
121119
if i == 0:
122120
continue
123121
# col1 has what we care about, but offer flexibility
124-
keys = row[1].split(",")
125-
this_obj = {"description": row[0].format(*row[2:])}
122+
keys = row[1].split(',')
123+
this_obj = {'description': row[0].format(*row[2:])}
126124
for j, k in enumerate(keys):
127125
this_obj[k.strip()] = row[j + 2]
128126
results.append(this_obj)
@@ -143,32 +141,27 @@ def _run_query(query_name: str, database_path: str, language: str, template_valu
143141
except RuntimeError:
144142
return f"The query {query_name} is not supported for language: {language}"
145143
try:
146-
csv = run_query(
147-
Path(__file__).parent.resolve() / query_path,
148-
database_path,
149-
fmt="csv",
150-
template_values=template_values,
151-
log_stderr=True,
152-
)
144+
csv = run_query(Path(__file__).parent.resolve() /
145+
query_path,
146+
database_path,
147+
fmt='csv',
148+
template_values=template_values,
149+
log_stderr=True)
153150
return _csv_parse(csv)
154151
except Exception as e:
155152
return f"The query {query_name} encountered an error: {e}"
156153

157-
158154
backend = CodeqlSqliteBackend(MEMORY)
159155

160-
161156
@mcp.tool()
162-
def remote_sources(
163-
owner: str = Field(description="The owner of the GitHub repository"),
164-
repo: str = Field(description="The name of the GitHub repository"),
165-
database_path: str = Field(description="The CodeQL database path."),
166-
language: str = Field(description="The language used for the CodeQL database."),
167-
):
157+
def remote_sources(owner: str = Field(description="The owner of the GitHub repository"),
158+
repo: str = Field(description="The name of the GitHub repository"),
159+
database_path: str = Field(description="The CodeQL database path."),
160+
language: str = Field(description="The language used for the CodeQL database.")):
168161
"""List all remote sources and their locations in a CodeQL database, then store the results in a database."""
169162

170163
repo = process_repo(owner, repo)
171-
results = _run_query("remote_sources", database_path, language, {})
164+
results = _run_query('remote_sources', database_path, language, {})
172165

173166
# Check if results is an error (list of strings) or valid data (list of dicts)
174167
if isinstance(results, str):
@@ -179,67 +172,53 @@ def remote_sources(
179172
for result in results:
180173
backend.store_new_source(
181174
repo=repo,
182-
source_location=result.get("location", ""),
183-
source_type=result.get("source", ""),
184-
line=int(result.get("line", "0")),
185-
notes=None, # result.get('description', ''),
186-
update=False,
175+
source_location=result.get('location', ''),
176+
source_type=result.get('source', ''),
177+
line=int(result.get('line', '0')),
178+
notes=None, #result.get('description', ''),
179+
update=False
187180
)
188181
stored_count += 1
189182

190183
return f"Stored {stored_count} remote sources in {repo}."
191184

192-
193185
@mcp.tool()
194-
def fetch_sources(
195-
owner: str = Field(description="The owner of the GitHub repository"),
196-
repo: str = Field(description="The name of the GitHub repository"),
197-
):
186+
def fetch_sources(owner: str = Field(description="The owner of the GitHub repository"),
187+
repo: str = Field(description="The name of the GitHub repository")):
198188
"""
199189
Fetch all sources from the repo
200190
"""
201191
repo = process_repo(owner, repo)
202192
return json.dumps(backend.get_sources(repo))
203193

204-
205194
@mcp.tool()
206-
def add_source_notes(
207-
owner: str = Field(description="The owner of the GitHub repository"),
208-
repo: str = Field(description="The name of the GitHub repository"),
209-
source_location: str = Field(description="The path to the file"),
210-
line: int = Field(description="The line number of the source"),
211-
notes: str = Field(description="The notes to append to this source"),
212-
):
195+
def add_source_notes(owner: str = Field(description="The owner of the GitHub repository"),
196+
repo: str = Field(description="The name of the GitHub repository"),
197+
source_location: str = Field(description="The path to the file"),
198+
line: int = Field(description="The line number of the source"),
199+
notes: str = Field(description="The notes to append to this source")):
213200
"""
214201
Add new notes to an existing source. The notes will be appended to any existing notes.
215202
"""
216203
repo = process_repo(owner, repo)
217-
return backend.store_new_source(
218-
repo=repo, source_location=source_location, line=line, source_type="", notes=notes, update=True
219-
)
220-
204+
return backend.store_new_source(repo = repo, source_location = source_location, line = line, source_type = "", notes = notes, update=True)
221205

222206
@mcp.tool()
223-
def clear_codeql_repo(
224-
owner: str = Field(description="The owner of the GitHub repository"),
225-
repo: str = Field(description="The name of the GitHub repository"),
226-
):
207+
def clear_codeql_repo(owner: str = Field(description="The owner of the GitHub repository"),
208+
repo: str = Field(description="The name of the GitHub repository")):
227209
"""
228210
Clear all data for a given repo from the database
229211
"""
230212
repo = process_repo(owner, repo)
231213
with Session(backend.engine) as session:
232-
deleted_sources = session.query(Source).filter_by(repo=repo).delete()
214+
deleted_sources = session.query(Source).filter_by(repo = repo).delete()
233215
session.commit()
234216
return f"Cleared {deleted_sources} sources from repo {repo}."
235217

236-
237218
if __name__ == "__main__":
238219
# Check if codeql/python-all pack is installed, if not install it
239-
if not os.path.isdir("/.codeql/packages/codeql/python-all"):
240-
pack_path = importlib.resources.files("seclab_taskflows.mcp_servers.codeql_python.queries").joinpath(
241-
"mcp-python"
242-
)
220+
if not os.path.isdir('/.codeql/packages/codeql/python-all'):
221+
pack_path = importlib.resources.files('seclab_taskflows.mcp_servers.codeql_python.queries').joinpath('mcp-python')
243222
print(f"Installing CodeQL pack from {pack_path}")
244223
subprocess.run(["codeql", "pack", "install", pack_path])
245224
mcp.run(show_banner=False, transport="http", host="127.0.0.1", port=9998)

0 commit comments

Comments
 (0)