diff --git a/scripts/microgenerator/generate.py b/scripts/microgenerator/generate.py index bd1b1dbb1..ccaff07a4 100644 --- a/scripts/microgenerator/generate.py +++ b/scripts/microgenerator/generate.py @@ -24,11 +24,12 @@ import ast import os +import argparse import glob import logging +import re from collections import defaultdict -from pathlib import Path -from typing import List, Dict, Any +from typing import List, Dict, Any, Iterator from . import name_utils from . import utils @@ -511,7 +512,6 @@ def _generate_import_statement( Returns: A formatted, multi-line import statement string. """ - names = sorted(list(set([item[key] for item in context]))) names_str = ",\n ".join(names) return f"from {package} import (\n {names_str}\n)" @@ -542,7 +542,6 @@ def generate_code(config: Dict[str, Any], analysis_results: tuple) -> None: """ Generates source code files using Jinja2 templates. """ - data, all_imports, all_types, request_arg_schema = analysis_results project_root = config["project_root"] config_dir = config["config_dir"] @@ -618,3 +617,135 @@ def generate_code(config: Dict[str, Any], analysis_results: tuple) -> None: ) utils.write_code_to_file(output_path, final_code) + + +# ============================================================================= +# Section 4: Main Execution +# ============================================================================= + + +def setup_config_and_paths(config_path: str) -> Dict[str, Any]: + """Loads the configuration and sets up necessary paths. + + Args: + config_path: The path to the YAML configuration file. + + Returns: + A dictionary containing the loaded configuration and paths. + """ + + def find_project_root(start_path: str, markers: list[str]) -> str | None: + """Finds the project root by searching upwards for a marker.""" + current_path = os.path.abspath(start_path) + while True: + for marker in markers: + if os.path.exists(os.path.join(current_path, marker)): + return current_path + parent_path = os.path.dirname(current_path) + if parent_path == current_path: # Filesystem root + return None + current_path = parent_path + + # Load configuration from the YAML file. + config = utils.load_config(config_path) + + # Determine the project root. + script_dir = os.path.dirname(os.path.abspath(__file__)) + project_root = find_project_root(script_dir, ["setup.py", ".git"]) + if not project_root: + project_root = os.getcwd() # Fallback to current directory + + # Set paths in the config dictionary. + config["project_root"] = project_root + config["config_dir"] = os.path.dirname(os.path.abspath(config_path)) + + return config + + +def _execute_post_processing(config: Dict[str, Any]): + """ + Executes post-processing steps, such as patching existing files. + """ + project_root = config["project_root"] + post_processing_jobs = config.get("post_processing_templates", []) + + for job in post_processing_jobs: + template_path = os.path.join(config["config_dir"], job["template"]) + target_file_path = os.path.join(project_root, job["target_file"]) + + if not os.path.exists(target_file_path): + logging.warning( + f"Target file {target_file_path} not found, skipping post-processing job." + ) + continue + + # Read the target file + with open(target_file_path, "r") as f: + lines = f.readlines() + + # --- Extract existing imports and __all__ members --- + imports = [] + all_list = [] + all_start_index = -1 + all_end_index = -1 + + for i, line in enumerate(lines): + if line.strip().startswith("from ."): + imports.append(line.strip()) + if line.strip() == "__all__ = (": + all_start_index = i + if all_start_index != -1 and line.strip() == ")": + all_end_index = i + + if all_start_index != -1 and all_end_index != -1: + for i in range(all_start_index + 1, all_end_index): + member = lines[i].strip().replace('"', "").replace(",", "") + if member: + all_list.append(member) + + # --- Add new items and sort --- + for new_import in job.get("add_imports", []): + if new_import not in imports: + imports.append(new_import) + imports.sort() + imports = [f"{imp}\n" for imp in imports] # re-add newlines + + for new_member in job.get("add_to_all", []): + if new_member not in all_list: + all_list.append(new_member) + all_list.sort() + + # --- Render the new file content --- + template = utils.load_template(template_path) + new_content = template.render( + imports=imports, + all_list=all_list, + ) + + # --- Overwrite the target file --- + with open(target_file_path, "w") as f: + f.write(new_content) + + logging.info(f"Successfully post-processed and overwrote {target_file_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="A generic Python code generator for clients." + ) + parser.add_argument("config", help="Path to the YAML configuration file.") + args = parser.parse_args() + + # Load config and set up paths. + config = setup_config_and_paths(args.config) + + # Analyze the source code. + analysis_results = analyze_source_files(config) + + # Generate the new client code. + generate_code(config, analysis_results) + + # Run post-processing steps. + _execute_post_processing(config) + + # TODO: Ensure blacken gets called on the generated source files as a final step.