Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import os
import uuid
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
Expand Down Expand Up @@ -200,14 +201,19 @@ def _download_file(self, document: Document) -> Document | None:

file_path = self.file_root_path / Path(file_name)

if file_path.is_file():
# set access and modification time to now without redownloading the file
file_path.touch()

else:
# if the file exists, avoid downloading it and just update the timestamp
try:
os.utime(file_path, None)
except FileNotFoundError:
s3_key = self.s3_key_generation_function(document) if self.s3_key_generation_function else file_name
# we know that _storage is not None after warm_up() is called, but mypy does not know that
self._storage.download(key=s3_key, local_file_path=file_path) # type: ignore[union-attr]
# download to a temp path so that other downloaders running concurrently never see a partially-written file
tmp_path = file_path.with_name(f"{file_path.name}.tmp-{uuid.uuid4().hex}")
try:
# we know that _storage is not None after warm_up() is called, but mypy does not know that
self._storage.download(key=s3_key, local_file_path=tmp_path) # type: ignore[union-attr]
os.replace(tmp_path, file_path)
finally:
tmp_path.unlink(missing_ok=True)

document.meta["file_path"] = str(file_path)
return document
Expand Down
23 changes: 23 additions & 0 deletions integrations/amazon_bedrock/tests/test_s3_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,29 @@ def test_run_returns_empty_when_all_filtered(self, tmp_path, mock_s3_storage, mo
assert out["documents"] == []
mock_s3_storage.download.assert_not_called()

def test_download_writes_to_temp_path_then_renames(self, tmp_path, mock_boto3_session):
    """Check that a download is written to a temporary path before the final file appears.

    The stubbed storage backend records the path it is asked to write to and
    asserts that the final file is absent at that moment. After the run, the
    test checks that the recorded path differs from the final one, carries the
    ``test.pdf.tmp-`` prefix, and that the final file holds the full content.
    """
    downloader = S3Downloader(file_root_path=str(tmp_path))

    target = tmp_path / "test.pdf"
    seen_paths = []

    def fake_download(key, local_file_path: Path):
        # parameter names must stay as-is: the stub is invoked with keyword arguments
        destination = Path(local_file_path)
        seen_paths.append(destination)
        assert not target.exists(), "final path must not exist while download is in progress"
        destination.write_bytes(b"complete content")

    storage_stub = MagicMock(spec=S3Storage)
    storage_stub.download.side_effect = fake_download
    downloader._storage = storage_stub

    requested = Document(meta={"file_name": "test.pdf"})
    downloader.run(documents=[requested])

    # exactly one download, and it went to a temporary sibling of the final file
    assert len(seen_paths) == 1
    recorded = seen_paths[0]
    assert recorded != target
    assert recorded.name.startswith("test.pdf.tmp-")

    # the final file is in place with the fully written payload
    assert target.exists()
    assert target.read_bytes() == b"complete content"

def test_cleanup_cache_evicts_old_files(self, tmp_path, mock_s3_storage, mock_boto3_session):
d = S3Downloader(file_root_path=str(tmp_path), max_cache_size=1)
d._storage = mock_s3_storage
Expand Down
Loading