|
| 1 | +import tarfile |
| 2 | +import io |
| 3 | +from typing import Type, Callable |
| 4 | +from fnnx.variants.pyfunc import PyFunc |
| 5 | +import json |
| 6 | +from dataclasses import dataclass |
| 7 | +from dataclasses import asdict as dataclass_asdict |
| 8 | +import inspect |
| 9 | +import sys |
| 10 | + |
| 11 | +from pydantic import BaseModel as PydanticBaseModel |
| 12 | + |
| 13 | +from fnnx import __version__ as fnnx_version |
| 14 | + |
| 15 | +from fnnx.extras.pydantic_models.manifest import Manifest, NDJSON, JSON, Var |
| 16 | +from fnnx.extras.pydantic_models.variants.pyfunc import PyFuncVariant |
| 17 | +from fnnx.extras.pydantic_models.envs import Python3_CondaPip, PipDependency |
| 18 | + |
| 19 | + |
def asdict(obj):
    """Return a plain ``dict`` representation of *obj*.

    Pydantic models are serialised through their own ``model_dump``; anything
    else is handed to :func:`dataclasses.asdict`, which raises ``TypeError``
    when *obj* is not a dataclass instance.
    """
    if not isinstance(obj, PydanticBaseModel):
        return dataclass_asdict(obj)
    return obj.model_dump()
| 24 | + |
| 25 | + |
def asjson(obj):
    """Return *obj* serialised as a JSON string indented by four spaces.

    Pydantic models use their own JSON serialiser; every other object is first
    converted with ``asdict`` and then encoded with :func:`json.dumps`.
    """
    is_pydantic_model = isinstance(obj, PydanticBaseModel)
    return (
        obj.model_dump_json(indent=4)
        if is_pydantic_model
        else json.dumps(asdict(obj), indent=4)
    )
| 30 | + |
| 31 | + |
# "major.minor.micro" of the interpreter that is running this module.
PYTHON_VERSION = ".".join(str(part) for part in sys.version_info[:3])
| 35 | + |
| 36 | + |
@dataclass
class PyFuncSpec:
    """Reference to a pyfunc implementation by source file and class name.

    Used instead of passing the class object itself, for example when the
    class cannot or should not be imported by the caller.
    """

    # Path of the Python source file that contains the class.
    filepath: str
    # Name of the class inside that file.
    class_name: str
| 41 | + |
| 42 | + |
class PyfuncBuilder:
    """Collect everything that describes a pyfunc model and write it to a tar archive.

    Inputs, outputs, dynamic attributes, environment variables, extra dtypes,
    extra modules/files and dependencies are registered through the ``add_*`` /
    ``define_*`` / ``set_*`` methods; :meth:`save` then writes the archive.
    """

    def __init__(
        self,
        pyfunc: Type[PyFunc] | PyFuncSpec,
        model_name: str | None = None,
        model_version: str | None = None,
        model_description: str | None = None,
        create_meta_callback: Callable | None = None,
    ) -> None:
        """Create a builder for *pyfunc*.

        Args:
            pyfunc: A ``PyFunc`` subclass, or a ``PyFuncSpec`` naming the source
                file and class. The source file is read immediately.
            model_name: Optional model name stored in the manifest.
            model_version: Optional model version stored in the manifest.
            model_description: Optional description stored in the manifest.
            create_meta_callback: Optional callable invoked by :meth:`save` with
                the open ``File``; when given, the default ``meta.json`` is not
                written.

        Raises:
            TypeError: If *pyfunc* is neither a ``PyFunc`` subclass nor a
                ``PyFuncSpec`` instance.
            OSError: If the pyfunc source file cannot be read.
        """
        self._inputs: list[NDJSON | JSON] = []
        self._outputs: list[NDJSON | JSON] = []
        self._dynamic_attributes: list[Var] = []
        self._env_vars: list[Var] = []

        self._producer_name: str = "fnnx.ai"
        self._producer_version: str = fnnx_version
        self._producer_tags: list[str] = []

        self._extra_dtypes: dict = {}

        self._name = model_name
        self._version = model_version
        self._description = model_description

        self._extra_modules: list[str] = []
        self._extra_files: list[tuple[str, str]] = []
        self._extra_values: dict | None = None

        self._build_dependencies: list[str] = []
        self._rt_dependencies: list[str] = []

        self.create_meta_callback = create_meta_callback

        if isinstance(pyfunc, PyFuncSpec):
            self.pyfunc_name = pyfunc.class_name
            pyfunc_file = pyfunc.filepath
        # Guard with isinstance(..., type): issubclass() itself raises an
        # unrelated TypeError when handed a non-class object.
        elif isinstance(pyfunc, type) and issubclass(pyfunc, PyFunc):
            self.pyfunc_name = pyfunc.__name__
            pyfunc_file = inspect.getfile(pyfunc)
        else:
            raise TypeError(
                "Pyfunc must be a subclass of PyFunc or an instance of PyFuncSpec"
            )

        # The content is re-encoded as UTF-8 when written into the archive, so
        # read it as UTF-8 instead of relying on the platform locale.
        with open(pyfunc_file, encoding="utf-8") as f:
            self.pyfunc_content = f.read()

    @staticmethod
    def _module_name(module_path: str) -> str:
        """Return the last ``/``-separated component of *module_path*, ignoring trailing slashes."""
        return module_path.rstrip("/").split("/")[-1]

    def add_input(self, input_spec: NDJSON | JSON) -> None:
        """Register an input spec.

        Raises:
            TypeError: If *input_spec* is not an ``NDJSON`` or ``JSON`` instance.
            ValueError: If the name is already registered, or the dtype starts
                with ``ext::`` but was not declared via :meth:`define_dtype`.
        """
        if not isinstance(input_spec, (NDJSON, JSON)):
            raise TypeError("input_spec must be NDJSON or JSON instance")
        if any(existing.name == input_spec.name for existing in self._inputs):
            raise ValueError(f"input with name {input_spec.name} already exists")
        if (
            input_spec.dtype.startswith("ext::")
            and input_spec.dtype not in self._extra_dtypes
        ):
            raise ValueError(f"extra dtype with name {input_spec.dtype} not defined")
        self._inputs.append(input_spec)

    def add_output(self, output_spec: NDJSON | JSON) -> None:
        """Register an output spec.

        Raises:
            TypeError: If *output_spec* is not an ``NDJSON`` or ``JSON`` instance.
        """
        # NOTE(review): unlike add_input, duplicate names and undefined
        # "ext::" dtypes are not rejected here - confirm whether intentional.
        if not isinstance(output_spec, (NDJSON, JSON)):
            raise TypeError("output_spec must be NDJSON or JSON instance")
        self._outputs.append(output_spec)

    def add_dynamic_attribute(self, dynamic_attribute: Var) -> None:
        """Register a dynamic attribute.

        Raises:
            TypeError: If *dynamic_attribute* is not a ``Var`` instance.
            ValueError: If an attribute with the same name already exists.
        """
        if not isinstance(dynamic_attribute, Var):
            raise TypeError("dynamic_attribute must be Var instance")
        if any(
            existing.name == dynamic_attribute.name
            for existing in self._dynamic_attributes
        ):
            raise ValueError(
                f"dynamic_attribute with name {dynamic_attribute.name} already exists"
            )
        self._dynamic_attributes.append(dynamic_attribute)

    def add_env_var(self, env_var: Var) -> None:
        """Register an environment variable.

        Raises:
            TypeError: If *env_var* is not a ``Var`` instance.
            ValueError: If a variable with the same name already exists.
        """
        if not isinstance(env_var, Var):
            raise TypeError("env_var must be Var instance")
        if any(existing.name == env_var.name for existing in self._env_vars):
            raise ValueError(f"env_var with name {env_var.name} already exists")
        self._env_vars.append(env_var)

    def set_extra_values(self, values: dict) -> None:
        """Store a shallow copy of *values* for the variant config."""
        self._extra_values = values.copy()

    def define_dtype(self, name: str, dtype: Type[PydanticBaseModel]) -> None:
        """Declare an extra dtype, stored as the JSON schema of *dtype*.

        Raises:
            ValueError: If *name* does not start with ``ext::``.
        """
        if not name.startswith("ext::"):
            raise ValueError("dtype name must start with 'ext::'")
        self._extra_dtypes[name] = dtype.model_json_schema()

    def set_producer_info(
        self, name: str, version: str, tags: list[str] | None = None
    ) -> None:
        """Override the producer name, version and tags written to the manifest."""
        self._producer_name = name
        self._producer_version = version
        self._producer_tags = tags or []

    def add_module(self, module_path: str) -> None:
        """Register a module (file or directory) to copy under ``extra_modules``.

        The module is stored under the last component of *module_path*; a
        trailing ``/`` is ignored.

        Raises:
            ValueError: If no module name can be derived from *module_path*, or
                a module with the same name is already registered.
        """
        module_name = self._module_name(module_path)
        if not module_name:
            raise ValueError(f"cannot derive a module name from {module_path!r}")
        if any(self._module_name(known) == module_name for known in self._extra_modules):
            raise ValueError(f"module with name {module_name} already exists")
        self._extra_modules.append(module_path)

    def add_file(self, file_path: str, target_path: str) -> None:
        """Register *file_path* to be copied to ``extra_files/<target_path>``.

        Raises:
            ValueError: If *target_path* is already taken.
        """
        if any(target == target_path for _, target in self._extra_files):
            raise ValueError(f"file with name {target_path} already exists")
        self._extra_files.append((file_path, target_path))

    def save(self, path: str) -> None:
        """Write the model package to a tar archive at *path*.

        All JSON payloads are built before the archive is opened, and the
        archive is closed even if writing one of its members fails.
        """
        manifest = Manifest(
            variant="pyfunc",
            producer_name=self._producer_name,
            producer_version=self._producer_version,
            producer_tags=self._producer_tags,
            inputs=self._inputs,
            outputs=self._outputs,
            dynamic_attributes=self._dynamic_attributes,
            env_vars=self._env_vars,
            name=self._name,
            version=self._version,
            description=self._description,
        )
        manifest_json = asjson(manifest)
        dtypes_json = json.dumps(self._extra_dtypes)
        env_json = json.dumps(self._make_env())
        variant_config_json = asjson(
            PyFuncVariant(
                pyfunc_classname=self.pyfunc_name, extra_values=self._extra_values
            )
        )

        f = File(path)
        try:
            f.create_file("manifest.json", manifest_json)
            f.create_file("dtypes.json", dtypes_json)
            f.create_file("env.json", env_json)
            f.create_file("variant_artifacts/__pyfunc__.py", self.pyfunc_content)
            f.create_file("variant_config.json", variant_config_json)
            f.create_file("ops.json", "[]")
            f.make_artifacts_folders()
            if self.create_meta_callback:
                self.create_meta_callback(f)
            else:
                f.create_file("meta.json", "[]")

            for module in self._extra_modules:
                f.copy(
                    module,
                    f"variant_artifacts/extra_modules/{self._module_name(module)}",
                    should_exclude_pycache=True,
                )

            for file_path, target_path in self._extra_files:
                f.copy(file_path, f"variant_artifacts/extra_files/{target_path}")
        finally:
            # Always release the tar handle, including when a write above fails.
            f.close()

    def add_default_build_dependencies(self) -> None:
        """Pin the installed ``pip``, ``setuptools`` and ``wheel`` as build dependencies.

        Versions are read from the running interpreter's installed package
        metadata. A package that is not installed, or whose version does not
        match the expected format, is reported on stdout and skipped rather
        than recorded with a placeholder version.
        """
        import re
        from importlib import metadata

        version_pattern = re.compile(r"^\d+(\.\d+)*([a-zA-Z]+\d*)?$")

        for package in ("pip", "setuptools", "wheel"):
            try:
                version = metadata.version(package)
                if not version or not version_pattern.match(version):
                    raise ValueError(f"Invalid version format: ```{version}```")
            except (metadata.PackageNotFoundError, ValueError) as e:
                # Best effort: a missing or odd version must not abort the build.
                print(f"Error adding {package} version to build dependencies: {e}")
                continue
            self.add_build_dependency(f"{package}=={version}")

    def add_build_dependency(self, dep: str) -> None:
        """Append *dep* to the build dependencies."""
        self._build_dependencies.append(dep)

    def add_runtime_dependency(self, dep: str) -> None:
        """Append *dep* to the runtime dependencies."""
        # TODO add conditions for runtime dependencies
        self._rt_dependencies.append(dep)

    def add_fnnx_runtime_dependency(self, core: bool = False) -> None:
        """Pin the installed fnnx version (``fnnx[core]`` when *core* is true) as a runtime dependency."""
        dependency_name = "fnnx[core]" if core else "fnnx"
        self.add_runtime_dependency(f"{dependency_name}=={fnnx_version}")

    def _make_env(self):
        """Build the environment description written to ``env.json``."""
        conda_pip_env = Python3_CondaPip(
            python_version=PYTHON_VERSION,
            build_dependencies=self._build_dependencies,
            dependencies=[PipDependency(package=p) for p in self._rt_dependencies],
        )
        return {"python3::conda_pip": asdict(conda_pip_env)}
| 274 | + |
| 275 | + |
class File:
    """Thin wrapper around a writable, uncompressed tar archive.

    Supports the context-manager protocol so the archive is closed even when
    writing fails: ``with File(path) as f: ...``. Calling :meth:`close`
    explicitly keeps working.
    """

    def __init__(self, path) -> None:
        """Open (and truncate) the tar archive at *path* for writing."""
        self.tar = tarfile.open(path, "w")

    def __enter__(self) -> "File":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    def make_artifacts_folders(self) -> None:
        """Create the standard artifact folders by adding a ``.keep`` file to each."""
        for folder in (
            "meta_artifacts",
            "ops_artifacts",
            "variant_artifacts/extra_modules",
            "variant_artifacts/extra_files",
        ):
            self.create_file(f"{folder}/.keep", ".keep")

    def create_file(self, path, content: str) -> None:
        """Add a member named *path* whose body is *content* encoded as UTF-8."""
        payload = content.encode()
        info = tarfile.TarInfo(path)
        # The size must be the encoded byte length, not the character count.
        info.size = len(payload)
        self.tar.addfile(info, io.BytesIO(payload))

    def copy(self, src, dst, should_exclude_pycache: bool = False) -> None:
        """Add the file or directory tree *src* to the archive under the name *dst*.

        When *should_exclude_pycache* is true, any member that has a path
        component named exactly ``__pycache__`` is left out (returning ``None``
        from the filter also skips the contents of such a directory).
        """

        def exclude_pycache(tarinfo):
            # Compare whole path components so that a file whose name merely
            # contains the text "__pycache__" is still archived.
            if should_exclude_pycache and "__pycache__" in tarinfo.name.split("/"):
                return None
            return tarinfo

        self.tar.add(src, arcname=dst, filter=exclude_pycache)

    def close(self) -> None:
        """Finalize and close the underlying tar archive."""
        self.tar.close()
0 commit comments