|
| 1 | +import argparse |
| 2 | +import importlib.util |
| 3 | +import pathlib |
| 4 | +import shutil |
| 5 | +import sys |
| 6 | + |
| 7 | +_PROJECT_DIR = pathlib.Path(__file__).resolve().parents[1] |
| 8 | +_OPS_DIR = _PROJECT_DIR / "src" / "ninetoothed" / "ops" |
| 9 | + |
| 10 | + |
| 11 | +def _find_op_modules(): |
| 12 | + return { |
| 13 | + path.parent.name: path |
| 14 | + for path in sorted(_OPS_DIR.glob("*/build.py")) |
| 15 | + if path.is_file() |
| 16 | + } |
| 17 | + |
| 18 | + |
| 19 | +def _build_manifest(output_dir): |
| 20 | + return sorted( |
| 21 | + str(path) |
| 22 | + for path in pathlib.Path(output_dir).rglob("*.cpp") |
| 23 | + if not path.name.endswith(".tmp.cpp") |
| 24 | + ) |
| 25 | + |
| 26 | + |
| 27 | +def _write_cmake_manifest(output_dir, sources): |
| 28 | + manifest_path = pathlib.Path(output_dir) / "manifest.cmake" |
| 29 | + lines = ["set(INFINIOPS_NINETOOTHED_SOURCES"] |
| 30 | + lines.extend(f' "{source}"' for source in sources) |
| 31 | + lines.append(")") |
| 32 | + lines.append("") |
| 33 | + lines.append(f'set(INFINIOPS_NINETOOTHED_INCLUDE_DIRS "{output_dir}")') |
| 34 | + lines.append("") |
| 35 | + manifest_path.write_text("\n".join(lines) + "\n") |
| 36 | + |
| 37 | + |
| 38 | +def _load_op_module(op): |
| 39 | + path = _find_op_modules()[op] |
| 40 | + sys.path.insert(0, str(path.parent)) |
| 41 | + spec = importlib.util.spec_from_file_location(path.stem, path) |
| 42 | + module = importlib.util.module_from_spec(spec) |
| 43 | + assert spec.loader is not None |
| 44 | + sys.modules[spec.name] = module |
| 45 | + spec.loader.exec_module(module) |
| 46 | + |
| 47 | + return module |
| 48 | + |
| 49 | + |
| 50 | +def generate(ops, *, output_dir): |
| 51 | + op_modules = _find_op_modules() |
| 52 | + unknown_ops = tuple(op for op in ops if op not in op_modules) |
| 53 | + |
| 54 | + if unknown_ops: |
| 55 | + raise ValueError(f"unsupported NineToothed ops: {', '.join(unknown_ops)}") |
| 56 | + |
| 57 | + output_dir = pathlib.Path(output_dir) |
| 58 | + shutil.rmtree(output_dir, ignore_errors=True) |
| 59 | + output_dir.mkdir(parents=True, exist_ok=True) |
| 60 | + |
| 61 | + for op in ops: |
| 62 | + module = _load_op_module(op) |
| 63 | + module.build(output_dir) |
| 64 | + |
| 65 | + sources = _build_manifest(output_dir) |
| 66 | + _write_cmake_manifest(output_dir, sources) |
| 67 | + |
| 68 | + return sources |
| 69 | + |
| 70 | + |
| 71 | +def _parse_args(): |
| 72 | + parser = argparse.ArgumentParser( |
| 73 | + description="Generate NineToothed operator sources for InfiniOps." |
| 74 | + ) |
| 75 | + parser.add_argument("--output-dir", required=True) |
| 76 | + parser.add_argument("--ops", nargs="+", default=tuple(_find_op_modules())) |
| 77 | + |
| 78 | + return parser.parse_args() |
| 79 | + |
| 80 | + |
| 81 | +def main(): |
| 82 | + args = _parse_args() |
| 83 | + generate(args.ops, output_dir=args.output_dir) |
| 84 | + |
| 85 | + |
| 86 | +if __name__ == "__main__": |
| 87 | + main() |
0 commit comments