diff --git a/scripts/microgenerator/generate.py b/scripts/microgenerator/generate.py index bf14726a5..33d859df7 100644 --- a/scripts/microgenerator/generate.py +++ b/scripts/microgenerator/generate.py @@ -24,9 +24,13 @@ import ast import os +import glob +import logging +import re from collections import defaultdict -from typing import List, Dict, Any +from typing import List, Dict, Any, Iterator +from . import name_utils from . import utils # ============================================================================= @@ -352,3 +356,139 @@ def process_structure( return sorted(all_class_keys) else: return dict(results) + + +# ============================================================================= +# Section 2: Source file data gathering +# ============================================================================= + + +def _should_include_class(class_name: str, class_filters: Dict[str, Any]) -> bool: + """Checks if a class should be included based on filter criteria.""" + if class_filters.get("include_suffixes"): + if not class_name.endswith(tuple(class_filters["include_suffixes"])): + return False + if class_filters.get("exclude_suffixes"): + if class_name.endswith(tuple(class_filters["exclude_suffixes"])): + return False + return True + + +def _should_include_method(method_name: str, method_filters: Dict[str, Any]) -> bool: + """Checks if a method should be included based on filter criteria.""" + if method_filters.get("include_prefixes"): + if not any( + method_name.startswith(p) for p in method_filters["include_prefixes"] + ): + return False + if method_filters.get("exclude_prefixes"): + if any(method_name.startswith(p) for p in method_filters["exclude_prefixes"]): + return False + return True + + +def _build_request_arg_schema( + source_files: List[str], project_root: str +) -> Dict[str, List[str]]: + """Parses type files to build a schema of request classes and their _id arguments.""" + request_arg_schema: Dict[str, List[str]] = {} + for file_path in source_files: + if "/types/" not in file_path: + continue + + # Correctly determine the module name from the file path + relative_path = os.path.relpath(file_path, project_root) + module_name = os.path.splitext(relative_path)[0].replace(os.path.sep, ".") + + try: + structure, _, _ = parse_file(file_path) + if not structure: + continue + + for class_info in structure: + class_name = class_info.get("class_name", "Unknown") + if class_name.endswith("Request"): + full_class_name = f"{module_name}.{class_name}" + id_args = [ + attr["name"] + for attr in class_info.get("attributes", []) + if attr.get("name", "").endswith("_id") + ] + if id_args: + request_arg_schema[full_class_name] = id_args + except Exception as e: + logging.warning(f"Failed to parse {file_path}: {e}") + return request_arg_schema + + +def _process_service_clients( + source_files: List[str], class_filters: Dict, method_filters: Dict +) -> tuple[defaultdict, set, set]: + """Parses service client files to extract class and method information.""" + parsed_data = defaultdict(dict) + all_imports: set[str] = set() + all_types: set[str] = set() + + for file_path in source_files: + if "/services/" not in file_path: + continue + + structure, imports, types = parse_file(file_path) + all_imports.update(imports) + all_types.update(types) + + for class_info in structure: + class_name = class_info["class_name"] + if not _should_include_class(class_name, class_filters): + continue + + parsed_data[class_name] # Ensure class is in dict + + for method in class_info["methods"]: + method_name = method["method_name"] + if not _should_include_method(method_name, method_filters): + continue + parsed_data[class_name][method_name] = method + return parsed_data, all_imports, all_types + + +def analyze_source_files( + config: Dict[str, Any], +) -> tuple[Dict[str, Any], set[str], set[str], Dict[str, List[str]]]: + """ + Analyzes source files per the configuration to extract class and method info, + as well as information on imports and typehints. + + Args: + config: The generator's configuration dictionary. + + Returns: + A tuple containing: + - A dictionary containing the data needed for template rendering. + - A set of all import statements required by the parsed methods. + - A set of all type annotations found in the parsed methods. + - A dictionary mapping request class names to their `_id` arguments. + """ + project_root = config["project_root"] + source_patterns_dict = config.get("source_files", {}) + filter_rules = config.get("filter", {}) + class_filters = filter_rules.get("classes", {}) + method_filters = filter_rules.get("methods", {}) + + source_files = [] + for group in source_patterns_dict.values(): + for pattern in group: + # Make the pattern absolute + absolute_pattern = os.path.join(project_root, pattern) + source_files.extend(glob.glob(absolute_pattern, recursive=True)) + + # PASS 1: Build the request argument schema from the types files. + request_arg_schema = _build_request_arg_schema(source_files, project_root) + + # PASS 2: Process the service client files. + parsed_data, all_imports, all_types = _process_service_clients( + source_files, class_filters, method_filters + ) + + return parsed_data, all_imports, all_types, request_arg_schema +