Skip to content

Commit 1c2d50b

Browse files
authored
Cache flatc binary and schema extraction to fix 3x fbpkg export slowdown (#19104)
Differential Revision: D102214303 Pull Request resolved: #19104
1 parent b384173 commit 1c2d50b

2 files changed

Lines changed: 106 additions & 27 deletions

File tree

backends/xnnpack/serialization/xnnpack_graph_serialize.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import logging
1111
import os
1212
import tempfile
13+
import threading
1314
from dataclasses import dataclass, fields, is_dataclass
1415
from typing import ClassVar, Literal, Optional
1516

@@ -300,6 +301,29 @@ def pretty_print_xnngraph(xnnpack_graph_json: str, filename: Optional[str] = Non
300301
_delegate_instance_id = 0
301302

302303

304+
_cached_schema_bytes: Optional[bytes] = None
305+
# Lock protecting _cached_schema_bytes initialization. The race itself would
306+
# be benign (read_bytes() is idempotent), but a lock keeps this consistent
307+
# with the pattern used for caching the flatc binary in _flatbuffer.py.
308+
_schema_bytes_lock: threading.Lock = threading.Lock()
309+
310+
311+
def _get_schema_bytes() -> bytes:
    """Return the contents of ``schema.fbs``, reading the resource at most once.

    The bytes are kept in the module-level ``_cached_schema_bytes``; once it is
    populated, later calls return the stored value without touching the
    package resources again.
    """
    global _cached_schema_bytes
    cached = _cached_schema_bytes
    if cached is not None:
        # Already populated: return without taking the lock.
        return cached
    with _schema_bytes_lock:
        # Check again under the lock in case another thread filled the cache
        # while this one was waiting.
        if _cached_schema_bytes is None:
            schema_resource = _resources.files(serialization_package).joinpath(
                "schema.fbs"
            )
            _cached_schema_bytes = schema_resource.read_bytes()
        return _cached_schema_bytes
325+
326+
303327
def convert_to_flatbuffer(xnnpack_graph: XNNGraph) -> bytes:
304328
global _delegate_instance_id
305329
sanity_check_xnngraph_dataclass(xnnpack_graph)
@@ -316,11 +340,7 @@ def convert_to_flatbuffer(xnnpack_graph: XNNGraph) -> bytes:
316340
with tempfile.TemporaryDirectory() as d:
317341
schema_path = os.path.join(d, "schema.fbs")
318342
with open(schema_path, "wb") as schema_file:
319-
schema_file.write(
320-
_resources.files(serialization_package)
321-
.joinpath("schema.fbs")
322-
.read_bytes()
323-
)
343+
schema_file.write(_get_schema_bytes())
324344
json_path = os.path.join(d, "schema.json")
325345
with open(json_path, "wb") as json_file:
326346
json_file.write(xnnpack_graph_json.encode("ascii"))

exir/_serialize/_flatbuffer.py

Lines changed: 81 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77

88
# pyre-strict
99

10+
import atexit
11+
import contextlib
1012
import importlib.resources
1113
import os
1214
import re
1315
import shutil
1416
import stat
1517
import subprocess
16-
1718
import tempfile
19+
import threading
1820

1921
from dataclasses import dataclass
2022
from typing import Callable, Dict, List, Optional, Sequence
@@ -240,20 +242,41 @@ class _FlatbufferResult:
240242
# Name of an optional resource containing the `flatc` executable.
241243
_FLATC_RESOURCE_NAME: str = "flatbuffers-flatc"
242244

243-
244-
def _run_flatc(args: Sequence[str]) -> None:
245-
"""Runs the `flatc` command with the provided args.
246-
247-
If a resource matching _FLATC_RESOURCE_NAME exists, uses that executable.
248-
Otherwise, expects the `flatc` tool to be available on the system path.
249-
"""
250-
flatc_resource = importlib.resources.files(__package__).joinpath(
251-
_FLATC_RESOURCE_NAME
252-
)
253-
if flatc_resource.is_file():
254-
# Use the provided flatc binary.
255-
with importlib.resources.as_file(flatc_resource) as flatc_path:
256-
# Ensure the binary has execute permissions (needed for PAR files)
245+
# Cached flatc binary path. In PAR files, importlib.resources.as_file()
246+
# extracts the binary to a temp file on each call. With 200+ XNNPACK
247+
# partitions this adds ~30 min of overhead. Caching avoids re-extraction.
248+
# The ExitStack is registered with atexit so the extracted temp file is
249+
# cleaned up on normal process exit.
250+
#
251+
# Fork safety: the child inherits the parent's atexit registry and cached
252+
# path. Without _reset_flatc_cache_after_fork, the child's atexit would
253+
# run the inherited handler and unlink the parent's temp file. The
254+
# before/after_in_parent callbacks hold _flatc_lock across fork so the
255+
# child never inherits a half-initialized cache.
256+
_flatc_cached_path: Optional[str] = None
257+
_flatc_exit_stack: Optional[contextlib.ExitStack] = None
258+
_flatc_lock: threading.Lock = threading.Lock()
259+
260+
261+
def _get_flatc_path() -> str:
262+
"""Returns the path to the flatc executable, caching the result."""
263+
global _flatc_cached_path, _flatc_exit_stack
264+
# Double-checked locking: fast path avoids the lock once cached.
265+
if _flatc_cached_path is not None:
266+
return _flatc_cached_path
267+
268+
with _flatc_lock:
269+
if _flatc_cached_path is not None:
270+
return _flatc_cached_path
271+
272+
flatc_resource = importlib.resources.files(__package__).joinpath(
273+
_FLATC_RESOURCE_NAME
274+
)
275+
if flatc_resource.is_file():
276+
exit_stack = contextlib.ExitStack()
277+
flatc_path = exit_stack.enter_context(
278+
importlib.resources.as_file(flatc_resource)
279+
)
257280
try:
258281
current_mode = flatc_path.stat().st_mode
259282
if not (current_mode & stat.S_IXUSR):
@@ -262,13 +285,49 @@ def _run_flatc(args: Sequence[str]) -> None:
262285
)
263286
except OSError:
264287
pass
265-
subprocess.run([flatc_path] + list(args), check=True)
266-
else:
267-
# Expect the `flatc` tool to be on the system path or set as an env var.
268-
flatc_path = os.getenv("FLATC_EXECUTABLE")
269-
if not flatc_path:
270-
flatc_path = "flatc"
271-
subprocess.run([flatc_path] + list(args), check=True)
288+
_flatc_exit_stack = exit_stack
289+
# Clean up the extracted temp file on normal process exit.
290+
atexit.register(exit_stack.close)
291+
_flatc_cached_path = str(flatc_path)
292+
else:
293+
_flatc_cached_path = os.getenv("FLATC_EXECUTABLE", "flatc")
294+
295+
return _flatc_cached_path
296+
297+
298+
# Exit stacks inherited from a parent process across fork(). They are parked
# here and never closed by the child: merely dropping the last reference would
# let garbage collection finalize the context manager held by the stack, which
# for a temp-file-backed resource removes the file the parent is still using.
_flatc_inherited_exit_stacks: List[contextlib.ExitStack] = []


def _reset_flatc_cache_after_fork() -> None:
    """Reset the flatc cache in the child after fork.

    Steps, all performed in the child process:

    * Unregister the inherited atexit handler. Do NOT call ``.close()``: the
      parent still owns the extracted file.
    * Park the inherited ExitStack in ``_flatc_inherited_exit_stacks`` rather
      than dropping it. Releasing the last reference would finalize the
      context manager it holds and run that manager's cleanup, i.e. the same
      effect as calling ``.close()``.
    * Clear the cached path and stack so the child re-extracts lazily.
    * Replace the lock: the inherited one was acquired by the ``before`` fork
      callback and nothing in the child will release it.

    NOTE(review): the parked stack can presumably still be finalized if the
    child goes through a normal interpreter shutdown (module globals are torn
    down then); children that exit via ``os._exit`` are unaffected. Confirm
    whether long-lived parents fork children that exit normally.
    """
    global _flatc_cached_path, _flatc_exit_stack, _flatc_lock
    inherited_stack = _flatc_exit_stack
    if inherited_stack is not None:
        # Bound methods compare equal by (self, function), so this matches the
        # handler registered by the parent.
        atexit.unregister(inherited_stack.close)
        # Keep the stack alive so its cleanup does not run in this process.
        _flatc_inherited_exit_stacks.append(inherited_stack)
    _flatc_cached_path = None
    _flatc_exit_stack = None
    _flatc_lock = threading.Lock()
313+
314+
315+
# os.register_at_fork is Unix-only; guard for Windows importability.
316+
if hasattr(os, "register_at_fork"):

    def _acquire_flatc_lock_before_fork() -> None:
        # Resolve _flatc_lock at call time: the child rebinds the global to a
        # fresh lock, so the callback must not capture the original object.
        _flatc_lock.acquire()

    def _release_flatc_lock_in_parent() -> None:
        # Undo the acquire performed just before the fork.
        _flatc_lock.release()

    # Hold the lock across fork() in the parent; the child resets its own
    # cache state and lock instead of releasing the inherited one.
    os.register_at_fork(
        before=_acquire_flatc_lock_before_fork,
        after_in_parent=_release_flatc_lock_in_parent,
        after_in_child=_reset_flatc_cache_after_fork,
    )
322+
323+
324+
def _run_flatc(args: Sequence[str]) -> None:
    """Invoke `flatc` with ``args`` and fail loudly on a non-zero exit.

    The executable is whatever `_get_flatc_path()` returns.

    Args:
        args: Command-line arguments passed to `flatc`, excluding the
            executable itself.

    Raises:
        subprocess.CalledProcessError: If `flatc` exits with a non-zero
            status (``check=True``).
    """
    command = [_get_flatc_path(), *args]
    subprocess.run(command, check=True)
272331

273332

274333
def _flatc_compile(output_dir: str, schema_path: str, json_path: str) -> None:

0 commit comments

Comments
 (0)