Skip to content
This repository was archived by the owner on Mar 6, 2026. It is now read-only.

Commit 9be38ef

Browse files
authored
Merge branch 'autogen' into refactor/cleans-up-minor-issues-that-prevent-bqclient-build
2 parents bef480e + 66504db commit 9be38ef

2 files changed

Lines changed: 256 additions & 2 deletions

File tree

scripts/microgenerator/generate.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@ def _get_type_str(self, node: ast.AST | None) -> str | None:
8585
# Handles forward references as strings, e.g., '"Dataset"'
8686
if isinstance(node, ast.Constant):
8787
return repr(node.value)
88+
89+
# Handles | union types, e.g., int | float
90+
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
91+
left_str = self._get_type_str(node.left)
92+
right_str = self._get_type_str(node.right)
93+
return f"{left_str} | {right_str}"
94+
8895
return None # Fallback for unhandled types
8996

9097
def _collect_types_from_node(self, node: ast.AST | None) -> None:

scripts/microgenerator/tests/unit/test_generate_analyzer.py

Lines changed: 249 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
import ast
1818
import pytest
19-
from scripts.microgenerator.generate import CodeAnalyzer
19+
import textwrap as tw
20+
from scripts.microgenerator.generate import parse_code, CodeAnalyzer
2021

2122
# --- Tests CodeAnalyzer handling of Imports ---
2223

@@ -93,7 +94,6 @@ def test_import_extraction(self, code_snippet, expected_imports):
9394

9495

9596
class TestCodeAnalyzerAttributes:
96-
9797
@pytest.mark.parametrize(
9898
"code_snippet, expected_structure",
9999
[
@@ -259,3 +259,250 @@ def test_attribute_extraction(self, code_snippet: str, expected_structure: list)
259259
item["attributes"].sort(key=lambda x: x["name"])
260260

261261
assert extracted == expected_structure
262+
263+
264+
# --- Mock Types ---
265+
class MyClass:
266+
pass
267+
268+
269+
class AnotherClass:
270+
pass
271+
272+
273+
class YetAnotherClass:
274+
pass
275+
276+
277+
def test_codeanalyzer_finds_class():
278+
code = tw.dedent(
279+
"""
280+
class MyClass:
281+
pass
282+
"""
283+
)
284+
analyzer = CodeAnalyzer()
285+
tree = ast.parse(code)
286+
analyzer.visit(tree)
287+
assert len(analyzer.structure) == 1
288+
assert analyzer.structure[0]["class_name"] == "MyClass"
289+
290+
291+
def test_codeanalyzer_finds_multiple_classes():
292+
code = tw.dedent(
293+
"""
294+
class ClassA:
295+
pass
296+
297+
298+
class ClassB:
299+
pass
300+
"""
301+
)
302+
analyzer = CodeAnalyzer()
303+
tree = ast.parse(code)
304+
analyzer.visit(tree)
305+
assert len(analyzer.structure) == 2
306+
class_names = sorted([c["class_name"] for c in analyzer.structure])
307+
assert class_names == ["ClassA", "ClassB"]
308+
309+
310+
def test_codeanalyzer_finds_method():
311+
code = tw.dedent(
312+
"""
313+
class MyClass:
314+
def my_method(self):
315+
pass
316+
"""
317+
)
318+
analyzer = CodeAnalyzer()
319+
tree = ast.parse(code)
320+
analyzer.visit(tree)
321+
assert len(analyzer.structure) == 1
322+
assert len(analyzer.structure[0]["methods"]) == 1
323+
assert analyzer.structure[0]["methods"][0]["method_name"] == "my_method"
324+
325+
326+
def test_codeanalyzer_finds_multiple_methods():
327+
code = tw.dedent(
328+
"""
329+
class MyClass:
330+
def method_a(self):
331+
pass
332+
333+
def method_b(self):
334+
pass
335+
"""
336+
)
337+
analyzer = CodeAnalyzer()
338+
tree = ast.parse(code)
339+
analyzer.visit(tree)
340+
assert len(analyzer.structure) == 1
341+
method_names = sorted([m["method_name"] for m in analyzer.structure[0]["methods"]])
342+
assert method_names == ["method_a", "method_b"]
343+
344+
345+
def test_codeanalyzer_no_classes():
346+
code = tw.dedent(
347+
"""
348+
def top_level_function():
349+
pass
350+
"""
351+
)
352+
analyzer = CodeAnalyzer()
353+
tree = ast.parse(code)
354+
analyzer.visit(tree)
355+
assert len(analyzer.structure) == 0
356+
357+
358+
def test_codeanalyzer_class_with_no_methods():
359+
code = tw.dedent(
360+
"""
361+
class MyClass:
362+
attribute = 123
363+
"""
364+
)
365+
analyzer = CodeAnalyzer()
366+
tree = ast.parse(code)
367+
analyzer.visit(tree)
368+
assert len(analyzer.structure) == 1
369+
assert analyzer.structure[0]["class_name"] == "MyClass"
370+
assert len(analyzer.structure[0]["methods"]) == 0
371+
372+
373+
# --- Test Data for Parameterization ---
374+
TYPE_TEST_CASES = [
375+
pytest.param(
376+
tw.dedent(
377+
"""
378+
class TestClass:
379+
def func(self, a: int, b: str) -> bool: return True
380+
"""
381+
),
382+
[("a", "int"), ("b", "str")],
383+
"bool",
384+
id="simple_types",
385+
),
386+
pytest.param(
387+
tw.dedent(
388+
"""
389+
from typing import Optional
390+
class TestClass:
391+
def func(self, a: Optional[int]) -> str | None: return 'hello'
392+
"""
393+
),
394+
[("a", "Optional[int]")],
395+
"str | None",
396+
id="optional_union_none",
397+
),
398+
pytest.param(
399+
tw.dedent(
400+
"""
401+
from typing import Union
402+
class TestClass:
403+
def func(self, a: int | float, b: Union[str, bytes]) -> None: pass
404+
"""
405+
),
406+
[("a", "int | float"), ("b", "Union[str, bytes]")],
407+
"None",
408+
id="union_types",
409+
),
410+
pytest.param(
411+
tw.dedent(
412+
"""
413+
from typing import List, Dict, Tuple
414+
class TestClass:
415+
def func(self, a: List[int], b: Dict[str, float]) -> Tuple[int, str]: return (1, 'a')
416+
"""
417+
),
418+
[("a", "List[int]"), ("b", "Dict[str, float]")],
419+
"Tuple[int, str]",
420+
id="generic_types",
421+
),
422+
pytest.param(
423+
tw.dedent(
424+
"""
425+
import datetime
426+
from scripts.microgenerator.tests.unit.test_generate_analyzer import MyClass
427+
class TestClass:
428+
def func(self, a: datetime.date, b: MyClass) -> MyClass: return b
429+
"""
430+
),
431+
[("a", "datetime.date"), ("b", "MyClass")],
432+
"MyClass",
433+
id="imported_types",
434+
),
435+
pytest.param(
436+
tw.dedent(
437+
"""
438+
from scripts.microgenerator.tests.unit.test_generate_analyzer import AnotherClass, YetAnotherClass
439+
class TestClass:
440+
def func(self, a: 'AnotherClass') -> 'YetAnotherClass': return AnotherClass()
441+
"""
442+
),
443+
[("a", "'AnotherClass'")],
444+
"'YetAnotherClass'",
445+
id="forward_refs",
446+
),
447+
pytest.param(
448+
tw.dedent(
449+
"""
450+
class TestClass:
451+
def func(self, a, b): return a + b
452+
"""
453+
),
454+
[("a", None), ("b", None)], # No annotations means type is None
455+
None,
456+
id="no_annotations",
457+
),
458+
pytest.param(
459+
tw.dedent(
460+
"""
461+
from typing import List, Optional, Dict, Union, Any
462+
class TestClass:
463+
def func(self, a: List[Optional[Dict[str, Union[int, str]]]]) -> Dict[str, Any]: return {}
464+
"""
465+
),
466+
[("a", "List[Optional[Dict[str, Union[int, str]]]]")],
467+
"Dict[str, Any]",
468+
id="complex_nested",
469+
),
470+
pytest.param(
471+
tw.dedent(
472+
"""
473+
from typing import Literal
474+
class TestClass:
475+
def func(self, a: Literal['one', 'two']) -> Literal[True]: return True
476+
"""
477+
),
478+
[("a", "Literal['one', 'two']")],
479+
"Literal[True]",
480+
id="literal_type",
481+
),
482+
]
483+
484+
485+
class TestCodeAnalyzerArgsReturns:
486+
@pytest.mark.parametrize(
487+
"code_snippet, expected_args, expected_return", TYPE_TEST_CASES
488+
)
489+
def test_type_extraction(self, code_snippet, expected_args, expected_return):
490+
structure, imports, types = parse_code(code_snippet)
491+
492+
assert len(structure) == 1, "Should parse one class"
493+
class_info = structure[0]
494+
assert class_info["class_name"] == "TestClass"
495+
496+
assert len(class_info["methods"]) == 1, "Should find one method"
497+
method_info = class_info["methods"][0]
498+
assert method_info["method_name"] == "func"
499+
500+
# Extract args, skipping 'self'
501+
extracted_args = []
502+
for arg in method_info.get("args", []):
503+
if arg["name"] == "self":
504+
continue
505+
extracted_args.append((arg["name"], arg["type"]))
506+
507+
assert extracted_args == expected_args
508+
assert method_info.get("return_type") == expected_return

0 commit comments

Comments
 (0)