Skip to content

Commit e21f0ce

Browse files
committed
chore(ai): azure/s3 store prefix/no-prefix fixes
1 parent 1995bce commit e21f0ce

2 files changed

Lines changed: 22 additions & 6 deletions

File tree

packages/ai/src/ai/account/store_providers/azure/azure.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -331,14 +331,17 @@ def _match(name):
331331
if name_pattern == '..':
332332
raise StorageError(f'Path traversal detected: {name_pattern}')
333333

334+
def _part_len(path):
335+
return path.count('/') + 1 if path else 0
336+
334337
client = self._get_client()
335338
container_client = client.get_container_client(self._container)
336339

337340
files = []
338341
seen_dirs = set() if recursive and include_dirs else None
339342

340343
blob_prefix = self._get_blob_name(prefix) if prefix else self._prefix
341-
prefix_part_len = blob_prefix.count('/') - self._prefix.count('/')
344+
prefix_part_len = _part_len(blob_prefix) - _part_len(self._prefix)
342345

343346
blob_list = (
344347
await asyncio.to_thread(container_client.list_blobs, name_starts_with=blob_prefix)
@@ -552,5 +555,10 @@ def _get_blob_name(self, path: str) -> str:
552555
# Ensure the resolved name stays within the prefix
553556
if not full_name.startswith(self._prefix + '/') and full_name != self._prefix:
554557
raise StorageError(f'Path traversal detected: {path}')
555-
return full_name
556-
return posixpath.normpath(path)
558+
else:
559+
full_name = posixpath.normpath(path)
560+
# This isn't a path traversal case, but let's still raise
561+
# an error to ensure consistency across all providers
562+
if full_name.startswith('../') or full_name == '..':
563+
raise StorageError(f'Path traversal detected: {path}')
564+
return full_name

packages/ai/src/ai/account/store_providers/s3/s3.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -328,9 +328,12 @@ async def list_entries(
328328
if name_pattern == '..':
329329
raise StorageError(f'Path traversal detected: {name_pattern}')
330330

331+
def _part_len(path):
332+
return path.count('/') + 1 if path else 0
333+
331334
client = self._get_client()
332335
key_prefix = self._get_key(prefix) if prefix else self._prefix
333-
prefix_part_len = key_prefix.count('/') - self._prefix.count('/')
336+
prefix_part_len = _part_len(key_prefix) - _part_len(self._prefix)
334337

335338
def _match(name):
336339
return not name_pattern or fnmatch(name, name_pattern)
@@ -591,5 +594,10 @@ def _get_key(self, path: str) -> str:
591594
# Ensure the resolved key stays within the prefix
592595
if not full_key.startswith(self._prefix + '/') and full_key != self._prefix:
593596
raise StorageError(f'Path traversal detected: {path}')
594-
return full_key
595-
return posixpath.normpath(path)
597+
else:
598+
full_key = posixpath.normpath(path)
599+
# This isn't a path traversal case, but let's still raise
600+
# an error to ensure consistency across all providers
601+
if full_key.startswith('../') or full_key == '..':
602+
raise StorageError(f'Path traversal detected: {path}')
603+
return full_key

0 commit comments

Comments
 (0)