Skip to content

Commit c54b64a

Browse files
authored
fix: use FileLock with atomic rename for safe concurrent model downloads (#396)
When KEDA scales up multiple ExApp pods sharing storage, concurrent downloads of the same model can corrupt files. Use OS-level FileLock (auto-releases on process death) with temp file + atomic `os.replace()` to ensure the final file is always complete or absent, never partial. Unlike `SoftFileLock`, `FileLock` cannot leave stale locks after pod SIGKILL. No lock added for snapshot_download - huggingface_hub handles this internally. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Model downloads now use file locking and atomic writes to prevent corruption during concurrent operations; lock timeouts surface as clear fetch errors. * ETag-based conditional checks skip re-downloading unchanged files. * **Tests** * Added comprehensive unit tests for download behavior, error handling, cleanup, progress reporting, and concurrency. * **Chores** * Updated tooling to include additional test paths and lint rules; added a runtime dependency for file locking. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Oleksander Piskun <oleksandr2088@icloud.com>
1 parent b8aeb03 commit c54b64a

4 files changed

Lines changed: 279 additions & 39 deletions

File tree

.pre-commit-config.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ repos:
2020
nc_py_api/|
2121
benchmarks/|
2222
examples/|
23-
tests/
23+
tests/|
24+
tests_unit/
2425
)
2526
2627
- repo: https://github.com/psf/black
@@ -32,7 +33,8 @@ repos:
3233
nc_py_api/|
3334
benchmarks/|
3435
examples/|
35-
tests/
36+
tests/|
37+
tests_unit/
3638
)
3739
3840
- repo: https://github.com/tox-dev/pyproject-fmt

nc_py_api/ex_app/integration_fastapi.py

Lines changed: 53 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
status,
2121
)
2222
from fastapi.responses import JSONResponse, PlainTextResponse
23+
from filelock import FileLock
24+
from filelock import Timeout as FileLockTimeout
2325
from starlette.requests import HTTPConnection, Request
2426
from starlette.types import ASGIApp, Receive, Scope, Send
2527

@@ -158,7 +160,7 @@ def fetch_models_task(nc: NextcloudApp, models: dict[str, dict], progress_init_s
158160
"""Use for cases when you want to define custom `/init` but still need to easy download models.
159161
160162
:param nc: NextcloudApp instance.
161-
:param models_to_fetch: Dictionary describing which models should be downloaded of the form:
163+
:param models: Dictionary describing which models should be downloaded of the form:
162164
.. code-block:: python
163165
{
164166
"model_url_1": {
@@ -205,42 +207,56 @@ def __fetch_model_as_file(
205207
current_progress: int, progress_for_task: int, nc: NextcloudApp, model_path: str, download_options: dict
206208
) -> str:
207209
result_path = download_options.pop("save_path", urlparse(model_path).path.split("/")[-1])
208-
with niquests.get(model_path, stream=True) as response:
209-
if not response.ok:
210-
raise ModelFetchError(
211-
f"Downloading of '{model_path}' failed, returned ({response.status_code}) {response.text}"
212-
)
213-
downloaded_size = 0
214-
linked_etag = ""
215-
for each_history in response.history:
216-
linked_etag = each_history.headers.get("X-Linked-ETag", "")
217-
if linked_etag:
218-
break
219-
if not linked_etag:
220-
linked_etag = response.headers.get("X-Linked-ETag", response.headers.get("ETag", ""))
221-
total_size = int(response.headers.get("Content-Length"))
222-
try:
223-
existing_size = os.path.getsize(result_path)
224-
except OSError:
225-
existing_size = 0
226-
if linked_etag and total_size == existing_size:
227-
with builtins.open(result_path, "rb") as file:
228-
sha256_hash = hashlib.sha256()
229-
for byte_block in iter(lambda: file.read(4096), b""):
230-
sha256_hash.update(byte_block)
231-
if f'"{sha256_hash.hexdigest()}"' == linked_etag:
232-
nc.set_init_status(min(current_progress + progress_for_task, 99))
233-
return result_path
234-
235-
with builtins.open(result_path, "wb") as file:
236-
last_progress = current_progress
237-
for chunk in response.iter_raw(-1):
238-
downloaded_size += file.write(chunk)
239-
if total_size:
240-
new_progress = min(current_progress + int(progress_for_task * downloaded_size / total_size), 99)
241-
if new_progress != last_progress:
242-
nc.set_init_status(new_progress)
243-
last_progress = new_progress
210+
tmp_path = result_path + ".tmp"
211+
try:
212+
with FileLock(result_path + ".lock", timeout=7200), niquests.get(model_path, stream=True) as response:
213+
if not response.ok:
214+
raise ModelFetchError(
215+
f"Downloading of '{model_path}' failed, returned ({response.status_code}) {response.text}"
216+
)
217+
downloaded_size = 0
218+
linked_etag = ""
219+
for redirect_resp in response.history:
220+
linked_etag = redirect_resp.headers.get("X-Linked-ETag", "")
221+
if linked_etag:
222+
break
223+
if not linked_etag:
224+
linked_etag = response.headers.get("X-Linked-ETag", response.headers.get("ETag", ""))
225+
total_size = int(response.headers.get("Content-Length", 0))
226+
try:
227+
existing_size = os.path.getsize(result_path)
228+
except OSError:
229+
existing_size = 0
230+
if linked_etag and total_size and total_size == existing_size:
231+
with builtins.open(result_path, "rb") as file:
232+
sha256_hash = hashlib.sha256()
233+
for byte_block in iter(lambda: file.read(4096), b""):
234+
sha256_hash.update(byte_block)
235+
if f'"{sha256_hash.hexdigest()}"' == linked_etag:
236+
nc.set_init_status(min(current_progress + progress_for_task, 99))
237+
return result_path
238+
239+
try:
240+
with builtins.open(tmp_path, "wb") as file:
241+
last_progress = current_progress
242+
for chunk in response.iter_raw(-1):
243+
downloaded_size += file.write(chunk)
244+
if total_size:
245+
new_progress = min(
246+
current_progress + int(progress_for_task * downloaded_size / total_size), 99
247+
)
248+
if new_progress != last_progress:
249+
nc.set_init_status(new_progress)
250+
last_progress = new_progress
251+
os.replace(tmp_path, result_path)
252+
except BaseException:
253+
if os.path.exists(tmp_path):
254+
os.remove(tmp_path)
255+
raise
256+
except FileLockTimeout as exc:
257+
raise ModelFetchError(
258+
f"Timed out waiting for lock on '{result_path}' after 7200s — another process may be stuck downloading"
259+
) from exc
244260

245261
return result_path
246262

pyproject.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ dynamic = [
4646
]
4747
dependencies = [
4848
"fastapi>=0.109.2",
49+
"filelock>=3.20.3,<4",
4950
"niquests>=3,<4",
5051
"pydantic>=2.1.1",
5152
"python-dotenv>=1",
@@ -148,6 +149,12 @@ lint.extend-per-file-ignores."tests/**/*.py" = [
148149
"S",
149150
"UP",
150151
]
152+
lint.extend-per-file-ignores."tests_unit/**/*.py" = [
153+
"D",
154+
"E402",
155+
"S",
156+
"UP",
157+
]
151158
lint.mccabe.max-complexity = 16
152159

153160
[tool.isort]
@@ -198,6 +205,7 @@ messages_control.disable = [
198205
minversion = "6.0"
199206
testpaths = [
200207
"tests",
208+
"tests_unit",
201209
]
202210
filterwarnings = [
203211
"ignore::DeprecationWarning",
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
"""Tests for model file download with FileLock and atomic rename."""
2+
3+
import hashlib
4+
import os
5+
from threading import Thread
6+
from unittest import mock
7+
8+
import pytest
9+
from filelock import FileLock
10+
from filelock import Timeout as FileLockTimeout
11+
12+
from nc_py_api._exceptions import ModelFetchError
13+
from nc_py_api.ex_app.integration_fastapi import fetch_models_task
14+
15+
16+
class FakeResponse:
17+
"""Mock HTTP response for niquests.get() with streaming support."""
18+
19+
def __init__(self, content: bytes, etag: str = "", status_code: int = 200, ok: bool = True):
20+
self.content = content
21+
self.status_code = status_code
22+
self.ok = ok
23+
self.text = "" if ok else "Not Found"
24+
self.history = []
25+
sha = hashlib.sha256(content).hexdigest()
26+
self.headers = {
27+
"Content-Length": str(len(content)),
28+
"ETag": etag or f'"{sha}"',
29+
}
30+
31+
def iter_raw(self, _chunk_size):
32+
yield self.content
33+
34+
def __enter__(self):
35+
return self
36+
37+
def __exit__(self, *args):
38+
pass
39+
40+
41+
def _mock_nc():
42+
nc = mock.MagicMock()
43+
nc.set_init_status = mock.MagicMock()
44+
return nc
45+
46+
47+
class TestFetchModelAsFile:
48+
"""Tests for __fetch_model_as_file via fetch_models_task."""
49+
50+
def test_download_creates_file(self, tmp_path):
51+
content = b"model-data-abc"
52+
save_path = str(tmp_path / "model.bin")
53+
54+
with mock.patch("nc_py_api.ex_app.integration_fastapi.niquests.get", return_value=FakeResponse(content)):
55+
fetch_models_task(_mock_nc(), {"https://example.com/model.bin": {"save_path": save_path}}, 0)
56+
57+
assert os.path.isfile(save_path)
58+
with open(save_path, "rb") as f:
59+
assert f.read() == content
60+
61+
def test_no_tmp_file_remains_after_success(self, tmp_path):
62+
save_path = str(tmp_path / "model.bin")
63+
64+
with mock.patch("nc_py_api.ex_app.integration_fastapi.niquests.get", return_value=FakeResponse(b"data")):
65+
fetch_models_task(_mock_nc(), {"https://example.com/m.bin": {"save_path": save_path}}, 0)
66+
67+
assert not os.path.exists(save_path + ".tmp")
68+
69+
def test_lock_file_released_after_download(self, tmp_path):
70+
save_path = str(tmp_path / "model.bin")
71+
72+
with mock.patch("nc_py_api.ex_app.integration_fastapi.niquests.get", return_value=FakeResponse(b"data")):
73+
fetch_models_task(_mock_nc(), {"https://example.com/m.bin": {"save_path": save_path}}, 0)
74+
75+
lock_path = save_path + ".lock"
76+
# Lock file may or may not exist (FileLock implementation detail),
77+
# but it must not be held — acquiring it should succeed immediately.
78+
lock = FileLock(lock_path, timeout=1)
79+
lock.acquire()
80+
lock.release()
81+
82+
def test_skips_download_when_file_matches_etag(self, tmp_path):
83+
content = b"existing-model-data"
84+
sha = hashlib.sha256(content).hexdigest()
85+
etag = f'"{sha}"'
86+
save_path = str(tmp_path / "model.bin")
87+
with open(save_path, "wb") as f:
88+
f.write(content)
89+
90+
call_count = {"iter_raw": 0}
91+
original_iter_raw = FakeResponse.iter_raw
92+
93+
def tracking_iter_raw(self, chunk_size):
94+
call_count["iter_raw"] += 1
95+
yield from original_iter_raw(self, chunk_size)
96+
97+
resp = FakeResponse(content, etag=etag)
98+
resp.iter_raw = lambda cs: tracking_iter_raw(resp, cs)
99+
100+
with mock.patch("nc_py_api.ex_app.integration_fastapi.niquests.get", return_value=resp):
101+
fetch_models_task(_mock_nc(), {"https://example.com/model.bin": {"save_path": save_path}}, 0)
102+
103+
assert call_count["iter_raw"] == 0
104+
with open(save_path, "rb") as f:
105+
assert f.read() == content
106+
107+
def test_tmp_file_cleaned_up_on_download_error(self, tmp_path):
108+
save_path = str(tmp_path / "model.bin")
109+
110+
def failing_iter_raw(_chunk_size):
111+
yield b"partial"
112+
raise ConnectionError("network down")
113+
114+
resp = FakeResponse(b"full-content")
115+
resp.iter_raw = failing_iter_raw
116+
117+
with (
118+
mock.patch("nc_py_api.ex_app.integration_fastapi.niquests.get", return_value=resp),
119+
pytest.raises(ModelFetchError),
120+
):
121+
fetch_models_task(_mock_nc(), {"https://example.com/m.bin": {"save_path": save_path}}, 0)
122+
123+
assert not os.path.exists(save_path + ".tmp")
124+
assert not os.path.exists(save_path)
125+
126+
def test_original_file_untouched_on_download_error(self, tmp_path):
127+
save_path = str(tmp_path / "model.bin")
128+
with open(save_path, "wb") as f:
129+
f.write(b"original-good-data")
130+
131+
def failing_iter_raw(_chunk_size):
132+
yield b"partial"
133+
raise ConnectionError("network down")
134+
135+
resp = FakeResponse(b"new-content", etag='"different-etag"')
136+
resp.iter_raw = failing_iter_raw
137+
138+
with (
139+
mock.patch("nc_py_api.ex_app.integration_fastapi.niquests.get", return_value=resp),
140+
pytest.raises(ModelFetchError),
141+
):
142+
fetch_models_task(_mock_nc(), {"https://example.com/m.bin": {"save_path": save_path}}, 0)
143+
144+
with open(save_path, "rb") as f:
145+
assert f.read() == b"original-good-data"
146+
147+
def test_http_error_raises_model_fetch_error(self, tmp_path):
148+
save_path = str(tmp_path / "model.bin")
149+
resp = FakeResponse(b"", status_code=404, ok=False)
150+
151+
with (
152+
mock.patch("nc_py_api.ex_app.integration_fastapi.niquests.get", return_value=resp),
153+
pytest.raises(ModelFetchError),
154+
):
155+
fetch_models_task(_mock_nc(), {"https://example.com/m.bin": {"save_path": save_path}}, 0)
156+
157+
def test_concurrent_downloads_do_not_corrupt(self, tmp_path):
158+
save_path = str(tmp_path / "model.bin")
159+
errors = []
160+
161+
def download():
162+
try:
163+
fetch_models_task(_mock_nc(), {"https://example.com/m.bin": {"save_path": save_path}}, 0)
164+
except Exception as e: # noqa pylint: disable=broad-exception-caught
165+
errors.append(e)
166+
167+
# Patch once around both threads to avoid mock.patch context manager
168+
# race: independent per-thread patches can restore the original
169+
# function while the other thread still needs the mock.
170+
responses = iter([FakeResponse(b"A" * 10000), FakeResponse(b"B" * 10000)])
171+
172+
def mock_get(_url, **_kwargs):
173+
return next(responses)
174+
175+
with mock.patch("nc_py_api.ex_app.integration_fastapi.niquests.get", side_effect=mock_get):
176+
t1 = Thread(target=download)
177+
t2 = Thread(target=download)
178+
t1.start()
179+
t2.start()
180+
t1.join(timeout=60)
181+
t2.join(timeout=60)
182+
183+
assert not errors, f"Threads raised errors: {errors}"
184+
assert os.path.isfile(save_path)
185+
with open(save_path, "rb") as f:
186+
data = f.read()
187+
# File must be entirely one content or the other — never mixed
188+
assert data in (b"A" * 10000, b"B" * 10000)
189+
190+
def test_filelock_timeout_raises_model_fetch_error(self, tmp_path):
191+
save_path = str(tmp_path / "model.bin")
192+
lock = FileLock(save_path + ".lock")
193+
nc = _mock_nc()
194+
195+
with (
196+
mock.patch("nc_py_api.ex_app.integration_fastapi.FileLock", side_effect=FileLockTimeout(lock)),
197+
pytest.raises(ModelFetchError),
198+
):
199+
fetch_models_task(nc, {"https://example.com/m.bin": {"save_path": save_path}}, 0)
200+
201+
status_msg = nc.set_init_status.call_args_list[-1][0][1]
202+
assert "Timed out waiting for lock" in status_msg
203+
204+
def test_progress_updates_sent(self, tmp_path):
205+
save_path = str(tmp_path / "model.bin")
206+
nc = _mock_nc()
207+
208+
with mock.patch("nc_py_api.ex_app.integration_fastapi.niquests.get", return_value=FakeResponse(b"data")):
209+
fetch_models_task(nc, {"https://example.com/m.bin": {"save_path": save_path}}, 0)
210+
211+
# set_init_status should be called at least for completion (100)
212+
assert nc.set_init_status.called
213+
# Last call should be 100 (completion)
214+
assert nc.set_init_status.call_args_list[-1] == mock.call(100)

0 commit comments

Comments
 (0)