Skip to content

Commit 758b099

Browse files
committed
classify tests
Signed-off-by: Abhishek Kumar <abhishek.mrt22@gmail.com>
1 parent 30e6c22 commit 758b099

File tree

2 files changed

+1506
-0
lines changed

2 files changed

+1506
-0
lines changed
Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Auto-tag Marvin/CloudStack test files into categories:
4+
- business: no VM/Network/VPC/Volume/K8s cluster creation calls
5+
- lowres: exactly one creation call
- midres: more than one creation call, but <= number of tests (rough proxy)
6+
- hires: creation calls > number of tests
7+
- infra: host/systemvm/cluster/zone update related
8+
9+
This is heuristic/static analysis: expect a few false positives/negatives.
10+
Use manual overrides by editing the output file if needed.
11+
"""
12+
13+
from __future__ import annotations
14+
15+
import argparse
16+
import ast
17+
import json
18+
import os
19+
import re
20+
from dataclasses import dataclass, asdict
21+
from typing import Iterable, Optional, Set, Dict, List
22+
23+
try:
24+
import yaml # type: ignore
25+
except Exception:
26+
yaml = None # optional
27+
28+
29+
# ---- Heuristic patterns (tune these for your codebase) ----
#
# Each list entry is a regular-expression fragment. The fragments of a list are
# OR-ed together and compiled case-insensitively at the bottom of this section.
# A literal dot in a dotted call name (e.g. "VirtualMachine.create") must be
# written as "\." -- a bare "." would match any character.

# "Creation" operations: VM/Network/VPC/Volume/K8s cluster creation.
# We match both direct function names and API/client method names.
CREATION_NAME_PATTERNS = [
    r"\bVirtualMachine\.create\b|\bdeployVirtualMachine\b",
    r"\bNetwork\.create\b|\bcreateNetwork\b",
    r"\bVPC\.create\b|\bVpc\.create\b|\bcreateVpc\b|\bcreateVPC\b",
    r"\bVolume\.create\b|\bcreateVolume\b",
    r"\bTemplate\.create\b|\bISO\.create\b",
    r"\bcreateKubernetesCluster\b",
    # Common helper naming in tests:
    r"\bcreate_vm\b|\bdeploy_vm\b|\bdeployVm\b",
    r"\bcreate_network\b|\bcreate_vpc\b|\bcreate_volume\b",
    # The ".*" below is an intentional wildcard (e.g. create_isolated_network).
    r"\bcreate.*network\b|\bcreate.*Network\b|\bdeploy.*network\b|\bdeploy.*Network\b",
    r"\bcreate.*vpc\b|\bcreate.*Vpc\b|\bdeploy.*vpc\b|\bdeploy.*Vpc\b",
]

# Infra-related operations: host/systemvm/cluster/zone update, maintenance, etc.
INFRA_NAME_PATTERNS = [
    r"\baddHost\b|\bupdateHost\b|\bdeleteHost\b|\bHAForHost\b|\breconnectHost\b",
    r"\bupdateCluster\b|\baddCluster\b|\bdeleteCluster\b"
    r"|\bCluster\.create\b|\bCluster\.update\b|\bCluster\.delete\b",
    r"\bupdateZone\b|\bcreateZone\b|\bdeleteZone\b",
    r"\bSystemVm\b|\bSystemVM\b|\bSSVM\b|\bCPVM\b|\bSecondaryStorageVm\b",
    r"\bstartSystemVm\b|\bstopSystemVm\b|\brebootSystemVm\b",
    r"\bhostha\b|\bHostHA\b",
    r"\bShutdownCmd\b|\bupdateImageStoreCmd\b|\bStartCommand\b",
    r"\bMaintenance\b|\btriggerShutdownCmd\b|\bcancelShutdownCmd\b|\bmaintenance\b",
    r"\bprovisionCertificate\b|\bconfigureOutOfBandManagement\b",
]

# Some tests might not use obvious method names but include keywords in strings/comments.
# These are plain keywords (escaped before compiling), not regex fragments.
INFRA_TEXT_KEYWORDS = [
    "Maintenance", "overprovisioning", "StartCommand", "UPDATE host"
]


# Precompiled regex
CREATION_RE = re.compile("|".join(CREATION_NAME_PATTERNS), re.IGNORECASE)
INFRA_RE = re.compile("|".join(INFRA_NAME_PATTERNS), re.IGNORECASE)
INFRA_TEXT_RE = re.compile("|".join(re.escape(k) for k in INFRA_TEXT_KEYWORDS), re.IGNORECASE)
70+
71+
72+
@dataclass
class FileTagging:
    """Classification record produced for one analysed test file."""

    # Path of the analysed file, exactly as it was handed to the analyser.
    file: str
    # Assigned tag, e.g. "business", "lowres", "midres", "hires", "infra" or "unknown".
    category: str
    # Number of functions whose name starts with "test_".
    num_tests: int
    # Number of calls whose name matched the creation patterns.
    create_calls: int
    # Number of infra signals (matching call names and string constants).
    infra_hits: int
    # Human-readable reasons for the chosen category.
    notes: List[str]
80+
81+
82+
class Analyzer(ast.NodeVisitor):
83+
def __init__(self, source_text: str) -> None:
84+
self.source_text = source_text
85+
self.num_tests = 0
86+
self.create_calls = 0
87+
self.infra_hits = 0
88+
self.has_test_methods = False
89+
self._pass = 1 # Track which pass we're on
90+
91+
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
92+
if node.name.startswith("test_"):
93+
self.num_tests += 1
94+
self.has_test_methods = True
95+
self.generic_visit(node)
96+
97+
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
98+
# same logic for async tests
99+
if node.name.startswith("test_"):
100+
self.num_tests += 1
101+
self.has_test_methods = True
102+
self.generic_visit(node)
103+
104+
def visit_Call(self, node: ast.Call) -> None:
105+
# Count calls if the file has test methods (determined in first pass)
106+
should_count = self.has_test_methods
107+
108+
call_name = self._get_call_name(node.func)
109+
if call_name:
110+
if should_count and CREATION_RE.search(call_name):
111+
self.create_calls += 1
112+
if should_count and INFRA_RE.search(call_name):
113+
self.infra_hits += 1
114+
115+
self.generic_visit(node)
116+
117+
def visit_Constant(self, node: ast.Constant) -> None:
118+
# Look for infra-ish keywords in string constants as extra signal
119+
if isinstance(node.value, str):
120+
if INFRA_TEXT_RE.search(node.value):
121+
self.infra_hits += 1
122+
self.generic_visit(node)
123+
124+
@staticmethod
125+
def _get_call_name(func: ast.AST) -> Optional[str]:
126+
# Normalize possible call forms:
127+
# - foo(...)
128+
# - obj.foo(...)
129+
# - self.apiClient.createVirtualMachine(...)
130+
if isinstance(func, ast.Name):
131+
return func.id
132+
if isinstance(func, ast.Attribute):
133+
# Build dotted name
134+
parts = []
135+
cur: Optional[ast.AST] = func
136+
while isinstance(cur, ast.Attribute):
137+
parts.append(cur.attr)
138+
cur = cur.value
139+
if isinstance(cur, ast.Name):
140+
parts.append(cur.id)
141+
parts.reverse()
142+
return ".".join(parts)
143+
return None
144+
145+
146+
def categorize(num_tests: int, create_calls: int, infra_hits: int) -> tuple[str, List[str]]:
    """Pick a category from the counters gathered for one file.

    Rules, in priority order:

    1. ``infra``    -- any infra hit at all.
    2. ``business`` -- no creation calls.
    3. ``lowres``   -- exactly one creation call.
    4. ``midres``   -- more than one creation call, but at most ``num_tests``
       (or at most 2 when the file has no ``test_`` functions).
    5. ``hires``    -- everything else.

    Args:
        num_tests: Number of ``test_`` functions found in the file.
        create_calls: Number of resource-creation calls found.
        infra_hits: Number of infra-related signals found.

    Returns:
        A ``(category, notes)`` pair; ``notes`` holds one human-readable
        reason for the decision.
    """
    # Priority 1: Infra wins - any infra pattern marks it as infra
    if infra_hits > 0:
        return "infra", [f"infra_hits={infra_hits}"]

    # Priority 2: Business - no creation operations found
    if create_calls == 0:
        return "business", ["no creation calls found"]

    # Edge case: creation calls but no test_ functions. There is nothing to
    # compare against, so fall back to fixed thresholds on the call count.
    if num_tests <= 0:
        reason = f"no test_ functions found; create_calls={create_calls}"
        if create_calls == 1:
            return "lowres", [reason]
        if create_calls <= 2:  # midres threshold when no tests
            return "midres", [reason]
        return "hires", [reason]

    # Priority 3, 4, 5: classify based on create_calls vs num_tests
    if create_calls == 1:
        return "lowres", [f"create_calls({create_calls}) == 1"]
    if create_calls <= num_tests:
        return "midres", [f"1 < create_calls({create_calls}) <= num_tests({num_tests})"]
    return "hires", [f"create_calls({create_calls}) > num_tests({num_tests})"]
175+
176+
177+
def iter_test_files(root: str) -> Iterable[str]:
    """Lazily yield the path of every ``test_*.py`` file found below *root*.

    Paths are built by joining the walked directory with the file name, in
    whatever order ``os.walk`` reports them.
    """
    for directory, _subdirs, names in os.walk(root):
        yield from (
            os.path.join(directory, name)
            for name in names
            if name.startswith("test_") and name.endswith(".py")
        )
182+
183+
184+
def _unknown_tagging(path: str, note: str) -> FileTagging:
    """Build the placeholder record for a file that could not be read or parsed."""
    return FileTagging(
        file=path,
        category="unknown",
        num_tests=0,
        create_calls=0,
        infra_hits=0,
        notes=[note],
    )


def analyze_file(path: str) -> Optional[FileTagging]:
    """Read, parse and classify a single test file.

    Args:
        path: Path of the Python file to analyse (read as UTF-8).

    Returns:
        A FileTagging for the file. Files that cannot be read or parsed are
        reported with category ``"unknown"`` and the error text in ``notes``
        instead of raising.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            text = f.read()
        tree = ast.parse(text, filename=path)
    except SyntaxError as e:
        return _unknown_tagging(path, f"syntax error: {e}")
    except Exception as e:
        # Deliberately broad: one unreadable file must not abort the whole scan.
        return _unknown_tagging(path, f"read/parse error: {e}")

    # First pass: count test functions over the whole tree. The count is kept
    # in a local so the Analyzer's own traversal cannot add to it a second
    # time (seeding Analyzer.num_tests and then visiting doubled the value).
    num_tests = sum(
        1
        for node in ast.walk(tree)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
        and node.name.startswith("test_")
    )

    # Second pass: count creation/infra calls. The flag is set up front so
    # calls located before the first test_ function are counted as well.
    a = Analyzer(text)
    a.has_test_methods = num_tests > 0
    a.visit(tree)

    cat, notes = categorize(num_tests, a.create_calls, a.infra_hits)
    return FileTagging(
        file=path,
        category=cat,
        num_tests=num_tests,
        create_calls=a.create_calls,
        infra_hits=a.infra_hits,
        notes=notes,
    )
246+
247+
248+
def main() -> None:
    """Command-line entry point: scan a test tree and write the tag report.

    The report holds the absolute root, a per-category list of file paths and
    the per-file details. It is written as YAML or JSON, chosen by ``--format``
    or, when that is omitted, by the output file's extension (``.json`` means
    JSON, anything else YAML).

    Raises:
        SystemExit: If YAML output is requested without PyYAML installed, or
            (via argparse) if ``root`` is not a directory.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("root", help="Root directory containing Marvin tests")
    ap.add_argument("-o", "--out", default="test_tags.yaml", help="Output file (yaml or json)")
    ap.add_argument("--format", choices=["yaml", "json"], default=None, help="Force output format")
    args = ap.parse_args()

    out_format = args.format
    if out_format is None:
        out_format = "json" if args.out.lower().endswith(".json") else "yaml"

    # Fail before the scan rather than after it: the result could not be
    # written anyway.
    if out_format == "yaml" and yaml is None:
        raise SystemExit("PyYAML not installed. Either install pyyaml or use --format json / .json output.")

    # os.walk() silently yields nothing for a missing path, which would
    # produce an empty report that looks like a valid result.
    if not os.path.isdir(args.root):
        ap.error(f"root is not a directory: {args.root}")

    results: List[FileTagging] = [analyze_file(f) for f in sorted(iter_test_files(args.root))]

    payload = {
        "root": os.path.abspath(args.root),
        "generated_by": "tag_marvin_tests.py",
        # Known categories are pre-registered so they appear even when empty.
        "categories": {
            "business": [],
            "lowres": [],
            "midres": [],
            "hires": [],
            "infra": [],
            "unknown": [],
        },
        "details": [asdict(r) for r in results],
    }

    for r in results:
        payload["categories"].setdefault(r.category, []).append(r.file)

    with open(args.out, "w", encoding="utf-8") as f:
        if out_format == "yaml":
            yaml.safe_dump(payload, f, sort_keys=False)
        else:
            json.dump(payload, f, indent=2)

    print(f"Wrote {args.out}")
    for k, v in payload["categories"].items():
        print(f"{k:8s}: {len(v)} files")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)