Skip to content
This repository was archived by the owner on Mar 6, 2026. It is now read-only.

Commit 75e9ecf

Browse files
authored
Merge branch 'autogen' into refactor/rename-variables-edit-docstrings
2 parents 4715617 + 66504db commit 75e9ecf

2 files changed

Lines changed: 257 additions & 1 deletion

File tree

scripts/microgenerator/generate.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,13 @@ def _get_type_str(self, node: ast.AST | None) -> str | None:
8484
# Handles forward references as strings, e.g., '"Dataset"'
8585
if isinstance(node, ast.Constant):
8686
return repr(node.value)
87+
88+
# Handles | union types, e.g., int | float
89+
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
90+
left_str = self._get_type_str(node.left)
91+
right_str = self._get_type_str(node.right)
92+
return f"{left_str} | {right_str}"
93+
8794
return None # Fallback for unhandled types
8895

8996
def _collect_types_from_node(self, node: ast.AST | None) -> None:

scripts/microgenerator/tests/unit/test_generate_analyzer.py

Lines changed: 250 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
import ast
1818
import pytest
19-
from scripts.microgenerator.generate import CodeAnalyzer
19+
import textwrap as tw
20+
from scripts.microgenerator.generate import parse_code, CodeAnalyzer
2021

2122
# --- Tests CodeAnalyzer handling of Imports ---
2223

@@ -260,3 +261,251 @@ def test_attribute_extraction(
260261
item["attributes"].sort(key=lambda x: x["name"])
261262

262263
assert extracted == expected_analyzed_classes
264+
265+
266+
# --- Mock Types ---
267+
class MyClass:
268+
pass
269+
270+
271+
class AnotherClass:
272+
pass
273+
274+
275+
class YetAnotherClass:
276+
pass
277+
278+
279+
def test_codeanalyzer_finds_class():
280+
code = tw.dedent(
281+
"""
282+
class MyClass:
283+
pass
284+
"""
285+
)
286+
analyzer = CodeAnalyzer()
287+
tree = ast.parse(code)
288+
analyzer.visit(tree)
289+
assert len(analyzer.structure) == 1
290+
assert analyzer.structure[0]["class_name"] == "MyClass"
291+
292+
293+
def test_codeanalyzer_finds_multiple_classes():
294+
code = tw.dedent(
295+
"""
296+
class ClassA:
297+
pass
298+
299+
300+
class ClassB:
301+
pass
302+
"""
303+
)
304+
analyzer = CodeAnalyzer()
305+
tree = ast.parse(code)
306+
analyzer.visit(tree)
307+
assert len(analyzer.structure) == 2
308+
class_names = sorted([c["class_name"] for c in analyzer.structure])
309+
assert class_names == ["ClassA", "ClassB"]
310+
311+
312+
def test_codeanalyzer_finds_method():
313+
code = tw.dedent(
314+
"""
315+
class MyClass:
316+
def my_method(self):
317+
pass
318+
"""
319+
)
320+
analyzer = CodeAnalyzer()
321+
tree = ast.parse(code)
322+
analyzer.visit(tree)
323+
assert len(analyzer.structure) == 1
324+
assert len(analyzer.structure[0]["methods"]) == 1
325+
assert analyzer.structure[0]["methods"][0]["method_name"] == "my_method"
326+
327+
328+
def test_codeanalyzer_finds_multiple_methods():
329+
code = tw.dedent(
330+
"""
331+
class MyClass:
332+
def method_a(self):
333+
pass
334+
335+
def method_b(self):
336+
pass
337+
"""
338+
)
339+
analyzer = CodeAnalyzer()
340+
tree = ast.parse(code)
341+
analyzer.visit(tree)
342+
assert len(analyzer.structure) == 1
343+
method_names = sorted([m["method_name"] for m in analyzer.structure[0]["methods"]])
344+
assert method_names == ["method_a", "method_b"]
345+
346+
347+
def test_codeanalyzer_no_classes():
348+
code = tw.dedent(
349+
"""
350+
def top_level_function():
351+
pass
352+
"""
353+
)
354+
analyzer = CodeAnalyzer()
355+
tree = ast.parse(code)
356+
analyzer.visit(tree)
357+
assert len(analyzer.structure) == 0
358+
359+
360+
def test_codeanalyzer_class_with_no_methods():
361+
code = tw.dedent(
362+
"""
363+
class MyClass:
364+
attribute = 123
365+
"""
366+
)
367+
analyzer = CodeAnalyzer()
368+
tree = ast.parse(code)
369+
analyzer.visit(tree)
370+
assert len(analyzer.structure) == 1
371+
assert analyzer.structure[0]["class_name"] == "MyClass"
372+
assert len(analyzer.structure[0]["methods"]) == 0
373+
374+
375+
# --- Test Data for Parameterization ---
376+
TYPE_TEST_CASES = [
377+
pytest.param(
378+
tw.dedent(
379+
"""
380+
class TestClass:
381+
def func(self, a: int, b: str) -> bool: return True
382+
"""
383+
),
384+
[("a", "int"), ("b", "str")],
385+
"bool",
386+
id="simple_types",
387+
),
388+
pytest.param(
389+
tw.dedent(
390+
"""
391+
from typing import Optional
392+
class TestClass:
393+
def func(self, a: Optional[int]) -> str | None: return 'hello'
394+
"""
395+
),
396+
[("a", "Optional[int]")],
397+
"str | None",
398+
id="optional_union_none",
399+
),
400+
pytest.param(
401+
tw.dedent(
402+
"""
403+
from typing import Union
404+
class TestClass:
405+
def func(self, a: int | float, b: Union[str, bytes]) -> None: pass
406+
"""
407+
),
408+
[("a", "int | float"), ("b", "Union[str, bytes]")],
409+
"None",
410+
id="union_types",
411+
),
412+
pytest.param(
413+
tw.dedent(
414+
"""
415+
from typing import List, Dict, Tuple
416+
class TestClass:
417+
def func(self, a: List[int], b: Dict[str, float]) -> Tuple[int, str]: return (1, 'a')
418+
"""
419+
),
420+
[("a", "List[int]"), ("b", "Dict[str, float]")],
421+
"Tuple[int, str]",
422+
id="generic_types",
423+
),
424+
pytest.param(
425+
tw.dedent(
426+
"""
427+
import datetime
428+
from scripts.microgenerator.tests.unit.test_generate_analyzer import MyClass
429+
class TestClass:
430+
def func(self, a: datetime.date, b: MyClass) -> MyClass: return b
431+
"""
432+
),
433+
[("a", "datetime.date"), ("b", "MyClass")],
434+
"MyClass",
435+
id="imported_types",
436+
),
437+
pytest.param(
438+
tw.dedent(
439+
"""
440+
from scripts.microgenerator.tests.unit.test_generate_analyzer import AnotherClass, YetAnotherClass
441+
class TestClass:
442+
def func(self, a: 'AnotherClass') -> 'YetAnotherClass': return AnotherClass()
443+
"""
444+
),
445+
[("a", "'AnotherClass'")],
446+
"'YetAnotherClass'",
447+
id="forward_refs",
448+
),
449+
pytest.param(
450+
tw.dedent(
451+
"""
452+
class TestClass:
453+
def func(self, a, b): return a + b
454+
"""
455+
),
456+
[("a", None), ("b", None)], # No annotations means type is None
457+
None,
458+
id="no_annotations",
459+
),
460+
pytest.param(
461+
tw.dedent(
462+
"""
463+
from typing import List, Optional, Dict, Union, Any
464+
class TestClass:
465+
def func(self, a: List[Optional[Dict[str, Union[int, str]]]]) -> Dict[str, Any]: return {}
466+
"""
467+
),
468+
[("a", "List[Optional[Dict[str, Union[int, str]]]]")],
469+
"Dict[str, Any]",
470+
id="complex_nested",
471+
),
472+
pytest.param(
473+
tw.dedent(
474+
"""
475+
from typing import Literal
476+
class TestClass:
477+
def func(self, a: Literal['one', 'two']) -> Literal[True]: return True
478+
"""
479+
),
480+
[("a", "Literal['one', 'two']")],
481+
"Literal[True]",
482+
id="literal_type",
483+
),
484+
]
485+
486+
487+
class TestCodeAnalyzerArgsReturns:
488+
@pytest.mark.parametrize(
489+
"code_snippet, expected_args, expected_return", TYPE_TEST_CASES
490+
)
491+
def test_type_extraction(self, code_snippet, expected_args, expected_return):
492+
structure, imports, types = parse_code(code_snippet)
493+
494+
assert len(structure) == 1, "Should parse one class"
495+
class_info = structure[0]
496+
assert class_info["class_name"] == "TestClass"
497+
498+
assert len(class_info["methods"]) == 1, "Should find one method"
499+
method_info = class_info["methods"][0]
500+
assert method_info["method_name"] == "func"
501+
502+
# Extract args, skipping 'self'
503+
extracted_args = []
504+
for arg in method_info.get("args", []):
505+
if arg["name"] == "self":
506+
continue
507+
extracted_args.append((arg["name"], arg["type"]))
508+
509+
assert extracted_args == expected_args
510+
assert method_info.get("return_type") == expected_return
511+

0 commit comments

Comments
 (0)