diff --git a/scripts/microgenerator/generate.py b/scripts/microgenerator/generate.py index 33d859df7..bd1b1dbb1 100644 --- a/scripts/microgenerator/generate.py +++ b/scripts/microgenerator/generate.py @@ -26,9 +26,9 @@ import os import glob import logging -import re from collections import defaultdict -from typing import List, Dict, Any, Iterator +from pathlib import Path +from typing import List, Dict, Any from . import name_utils from . import utils @@ -492,3 +492,129 @@ def analyze_source_files( return parsed_data, all_imports, all_types, request_arg_schema + +# ============================================================================= +# Section 3: Code Generation +# ============================================================================= + + +def _generate_import_statement( + context: List[Dict[str, Any]], key: str, package: str +) -> str: + """Generates a formatted import statement from a list of context dictionaries. + + Args: + context: A list of dictionaries containing the data. + key: The key to extract from each dictionary in the context. + package: The base import package (e.g., "google.cloud.bigquery_v2.services"). + + Returns: + A formatted, multi-line import statement string. + """ + + names = sorted(list(set([item[key] for item in context]))) + names_str = ",\n ".join(names) + return f"from {package} import (\n {names_str}\n)" + + +def _get_request_class_name(method_name: str, config: Dict[str, Any]) -> str: + """Gets the inferred request class name, applying overrides from config.""" + inferred_request_name = name_utils.method_to_request_class_name(method_name) + method_overrides = config.get("filter", {}).get("methods", {}).get("overrides", {}) + if method_name in method_overrides: + return method_overrides[method_name].get( + "request_class_name", inferred_request_name + ) + return inferred_request_name + + +def _find_fq_request_name( + request_name: str, request_arg_schema: Dict[str, List[str]] +) -> str: + """Finds the fully qualified request name in the schema.""" + for key in request_arg_schema.keys(): + if key.endswith(f".{request_name}"): + return key + return "" + + +def generate_code(config: Dict[str, Any], analysis_results: tuple) -> None: + """ + Generates source code files using Jinja2 templates. + """ + + data, all_imports, all_types, request_arg_schema = analysis_results + project_root = config["project_root"] + config_dir = config["config_dir"] + + templates_config = config.get("templates", []) + for item in templates_config: + template_path = str(Path(config_dir) / item["template"]) + output_path = str(Path(project_root) / item["output"]) + + template = utils.load_template(template_path) + methods_context = [] + for class_name, methods in data.items(): + for method_name, method_info in methods.items(): + context = { + "name": method_name, + "class_name": class_name, + "return_type": method_info["return_type"], + } + + request_name = _get_request_class_name(method_name, config) + fq_request_name = _find_fq_request_name( + request_name, request_arg_schema + ) + + if fq_request_name: + context["request_class_full_name"] = fq_request_name + context["request_id_args"] = request_arg_schema[fq_request_name] + + methods_context.append(context) + + # Prepare imports for the template + services_context = [] + client_class_names = sorted( + list(set([m["class_name"] for m in methods_context])) + ) + + for class_name in client_class_names: + service_name_cluster = name_utils.generate_service_names(class_name) + services_context.append(service_name_cluster) + + # Also need to update methods_context to include the service_name and module_name + # so the template knows which client to use for each method. + class_to_service_map = {s["service_client_class"]: s for s in services_context} + for method in methods_context: + service_info = class_to_service_map.get(method["class_name"]) + if service_info: + method["service_name"] = service_info["service_name"] + method["service_module_name"] = service_info["service_module_name"] + + # Prepare new imports + service_imports = [ + _generate_import_statement( + services_context, + "service_module_name", + "google.cloud.bigquery_v2.services", + ) + ] + + # Prepare type imports + type_imports = [ + _generate_import_statement( + services_context, "service_name", "google.cloud.bigquery_v2.types" + ) + ] + + final_code = template.render( + service_name=config.get("service_name"), + methods=methods_context, + services=services_context, + service_imports=service_imports, + type_imports=type_imports, + request_arg_schema=request_arg_schema, + ) + + utils.write_code_to_file(output_path, final_code)