Skip to content

Commit 5a85040

Browse files
Merge pull request #544 from laughingman7743/max_depth_with_dirs
2 parents 7feea83 + 681e749 commit 5a85040

3 files changed

Lines changed: 288 additions & 52 deletions

File tree

CLAUDE.md

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ The project supports different cursor implementations for various use cases:
3131
## Development Guidelines
3232

3333
### Code Style and Quality
34+
35+
#### Commands
3436
```bash
3537
# Format code (auto-fix imports and format)
3638
make fmt
@@ -51,12 +53,67 @@ make tox
5153
make docs
5254
```
5355

56+
#### Docstring Style
57+
Use Google style docstrings for all public methods and complex internal methods:
58+
59+
```python
60+
def method_name(self, param1: str, param2: Optional[int] = None) -> List[str]:
61+
"""Brief description of what the method does.
62+
63+
Longer description if needed, explaining the method's behavior,
64+
edge cases, or important details.
65+
66+
Args:
67+
param1: Description of the first parameter.
68+
param2: Description of the optional parameter.
69+
70+
Returns:
71+
Description of the return value.
72+
73+
Raises:
74+
ValueError: When invalid parameters are provided.
75+
"""
76+
```
77+
5478
### Testing Requirements
79+
80+
#### General Guidelines
5581
1. **Unit Tests**: All new features must include unit tests
5682
2. **Integration Tests**: Test actual AWS Athena interactions when modifying query execution logic
5783
3. **SQLAlchemy Compliance**: Ensure SQLAlchemy dialect tests pass when modifying dialect code
5884
4. **Mock AWS Services**: Use `moto` or similar for testing AWS interactions without real resources
5985

86+
#### Writing Tests
87+
- Place tests in `tests/pyathena/` mirroring the source structure
88+
- Use pytest fixtures for common setup (see `conftest.py`)
89+
- Test both success and error cases
90+
- For filesystem operations, test edge cases like empty results, missing files, etc.
91+
92+
Example test structure:
93+
```python
94+
def test_find_maxdepth(self, fs):
95+
"""Test find with maxdepth parameter."""
96+
# Setup test data
97+
dir_ = f"s3://{ENV.s3_staging_bucket}/test_path"
98+
fs.touch(f"{dir_}/file0.txt")
99+
fs.touch(f"{dir_}/level1/file1.txt")
100+
101+
# Test maxdepth=0
102+
result = fs.find(dir_, maxdepth=0)
103+
assert len(result) == 1
104+
assert fs._strip_protocol(f"{dir_}/file0.txt") in result
105+
106+
# Test edge cases and error conditions
107+
with pytest.raises(ValueError):
108+
fs.find("s3://", maxdepth=0)
109+
```
110+
111+
#### Test Organization
112+
- Group related tests in classes (e.g., `TestS3FileSystem`)
113+
- Use descriptive test names that explain what is being tested
114+
- Keep tests focused and independent
115+
- Clean up test data after each test when using real AWS resources
116+
60117
### Common Development Tasks
61118

62119
#### Adding a New Feature
@@ -94,6 +151,8 @@ pyathena/
94151
│ └── requirements.py # SQLAlchemy requirements
95152
96153
└── filesystem/ # S3 filesystem abstractions
154+
├── s3.py # S3FileSystem implementation (fsspec compatible)
155+
└── s3_object.py # S3 object representations
97156
```
98157

99158
### Important Implementation Details
@@ -115,6 +174,21 @@ pyathena/
115174
- Follow DB API 2.0 exception hierarchy
116175
- Provide meaningful error messages that include Athena query IDs when available
117176

177+
#### S3 FileSystem Operations
178+
- `S3FileSystem` implements fsspec's `AbstractFileSystem` interface
179+
- Key methods include `ls()`, `find()`, `get()`, `put()`, `rm()`, etc.
180+
- `find()` method supports:
181+
- `maxdepth`: Limits directory traversal depth (uses recursive approach for efficiency)
182+
- `withdirs`: Controls whether directories are included in results (default: False)
183+
- Cache management uses `(path, delimiter)` as key to handle different listing modes
184+
- Always extract reusable logic into helper methods (e.g., `_extract_parent_directories()`)
185+
186+
When implementing filesystem methods:
187+
1. **Consider s3fs compatibility** - Many users migrate from s3fs, so matching its behavior is important
188+
2. **Optimize for S3's API** - Use delimiter="/" for recursive operations to minimize API calls
189+
3. **Handle edge cases** - Empty paths, trailing slashes, bucket-only paths
190+
4. **Test with real S3** - Mock tests may not catch S3-specific behaviors
191+
118192
### Performance Considerations
119193
1. **Result Caching**: Utilize Athena's result reuse feature (engine v3) when possible
120194
2. **Batch Operations**: Support `executemany()` for bulk operations

pyathena/filesystem/s3.py

Lines changed: 157 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -250,55 +250,57 @@ def _ls_dirs(
250250
bucket, key, version_id = self.parse_path(path)
251251
if key:
252252
prefix = f"{key}/{prefix if prefix else ''}"
253-
if path not in self.dircache or refresh:
254-
files: List[S3Object] = []
255-
while True:
256-
request: Dict[Any, Any] = {
257-
"Bucket": bucket,
258-
"Prefix": prefix,
259-
"Delimiter": delimiter,
260-
}
261-
if next_token:
262-
request.update({"ContinuationToken": next_token})
263-
if max_keys:
264-
request.update({"MaxKeys": max_keys})
265-
response = self._call(
266-
self._client.list_objects_v2,
267-
**request,
268-
)
269-
files.extend(
270-
S3Object(
271-
init={
272-
"ContentLength": 0,
273-
"ContentType": None,
274-
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
275-
"ETag": None,
276-
"LastModified": None,
277-
},
278-
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
279-
bucket=bucket,
280-
key=c["Prefix"][:-1].rstrip("/"),
281-
version_id=version_id,
282-
)
283-
for c in response.get("CommonPrefixes", [])
253+
254+
# Create a cache key that includes the delimiter
255+
cache_key = (path, delimiter)
256+
if cache_key in self.dircache and not refresh:
257+
return cast(List[S3Object], self.dircache[cache_key])
258+
259+
files: List[S3Object] = []
260+
while True:
261+
request: Dict[Any, Any] = {
262+
"Bucket": bucket,
263+
"Prefix": prefix,
264+
"Delimiter": delimiter,
265+
}
266+
if next_token:
267+
request.update({"ContinuationToken": next_token})
268+
if max_keys:
269+
request.update({"MaxKeys": max_keys})
270+
response = self._call(
271+
self._client.list_objects_v2,
272+
**request,
273+
)
274+
files.extend(
275+
S3Object(
276+
init={
277+
"ContentLength": 0,
278+
"ContentType": None,
279+
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
280+
"ETag": None,
281+
"LastModified": None,
282+
},
283+
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
284+
bucket=bucket,
285+
key=c["Prefix"][:-1].rstrip("/"),
286+
version_id=version_id,
284287
)
285-
files.extend(
286-
S3Object(
287-
init=c,
288-
type=S3ObjectType.S3_OBJECT_TYPE_FILE,
289-
bucket=bucket,
290-
key=c["Key"],
291-
)
292-
for c in response.get("Contents", [])
288+
for c in response.get("CommonPrefixes", [])
289+
)
290+
files.extend(
291+
S3Object(
292+
init=c,
293+
type=S3ObjectType.S3_OBJECT_TYPE_FILE,
294+
bucket=bucket,
295+
key=c["Key"],
293296
)
294-
next_token = response.get("NextContinuationToken")
295-
if not next_token:
296-
break
297-
if files:
298-
self.dircache[path] = files
299-
else:
300-
cache = self.dircache[path]
301-
files = cache if isinstance(cache, list) else [cache]
297+
for c in response.get("Contents", [])
298+
)
299+
next_token = response.get("NextContinuationToken")
300+
if not next_token:
301+
break
302+
if files:
303+
self.dircache[cache_key] = files
302304
return files
303305

304306
def ls(
@@ -396,27 +398,131 @@ def info(self, path: str, **kwargs) -> S3Object:
396398
)
397399
raise FileNotFoundError(path)
398400

399-
def find(
401+
def _extract_parent_directories(
    self, files: List[S3Object], bucket: str, base_key: Optional[str]
) -> List[S3Object]:
    """Extract parent directory objects from file paths.

    When listing files without delimiter, S3 doesn't return directory
    entries. This method creates directory objects by analyzing file paths.

    Args:
        files: List of S3Object instances representing files.
        bucket: S3 bucket name.
        base_key: Base key path to calculate relative paths from.

    Returns:
        List of S3Object instances representing directories, sorted by key.
        Sorting makes the output deterministic: a plain ``set`` iterates in
        hash order, which varies between interpreter runs.
    """
    dirs = set()
    base_key = base_key.rstrip("/") if base_key else ""

    for f in files:
        # Only real file entries contribute parent directories.
        if not f.key or f.type != S3ObjectType.S3_OBJECT_TYPE_FILE:
            continue
        # Compute the key path relative to base_key; files outside the
        # base prefix are ignored.
        if base_key:
            if not f.key.startswith(base_key + "/"):
                continue
            relative_path = f.key[len(base_key) + 1 :]
        else:
            relative_path = f.key

        # Register every intermediate directory along the relative path.
        parts = relative_path.split("/")
        for i in range(1, len(parts)):
            joined = "/".join(parts[:i])
            dirs.add(f"{base_key}/{joined}" if base_key else joined)

    # Materialize synthetic directory objects in deterministic (sorted) order.
    return [
        S3Object(
            init={
                "ContentLength": 0,
                "ContentType": None,
                "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
                "ETag": None,
                "LastModified": None,
            },
            type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
            bucket=bucket,
            key=dir_path,
            version_id=None,
        )
        for dir_path in sorted(dirs)
    ]
459+
460+
def _find(
    self,
    path: str,
    maxdepth: Optional[int] = None,
    withdirs: Optional[bool] = None,
    **kwargs,
) -> List[S3Object]:
    """Recursively list objects under *path*.

    Args:
        path: S3 path, with or without the ``s3://`` protocol prefix.
        maxdepth: Maximum directory depth to descend below *path*.
            ``0`` lists only the current level; ``None`` means unlimited.
        withdirs: Whether to include directory entries in the result
            (default behavior when ``None`` is to exclude them).
        **kwargs: ``prefix`` narrows the listing to keys starting with it;
            remaining kwargs are forwarded to recursive calls.

    Returns:
        List of S3Object instances found under the path.

    Raises:
        ValueError: If *path* is empty or the filesystem root.
    """
    path = self._strip_protocol(path)
    if path in ["", "/"]:
        raise ValueError("Cannot traverse all files in S3.")
    bucket, key, _ = self.parse_path(path)
    prefix = kwargs.pop("prefix", "")

    # When maxdepth is specified, walk level by level with delimiter="/"
    # so that traversal can stop at the requested depth.
    if maxdepth is not None:
        result: List[S3Object] = []
        current_items = self._ls_dirs(path, prefix=prefix, delimiter="/")
        for item in current_items:
            if item.type == S3ObjectType.S3_OBJECT_TYPE_FILE:
                result.append(item)
            elif item.type == S3ObjectType.S3_OBJECT_TYPE_DIRECTORY:
                if withdirs:
                    result.append(item)
                # Recursively explore the subdirectory while depth allows.
                if maxdepth > 0:
                    sub_path = f"s3://{bucket}/{item.key}"
                    result.extend(
                        self._find(sub_path, maxdepth=maxdepth - 1, withdirs=withdirs, **kwargs)
                    )
        # Fix: mirror the unlimited-depth branch below — if the listing is
        # empty but the path itself names a single object, return that
        # object instead of an empty result. Without this, find() on a file
        # path behaved differently depending on whether maxdepth was set.
        if not result and key:
            try:
                info = self.info(path)
                if withdirs or info.type != S3ObjectType.S3_OBJECT_TYPE_DIRECTORY:
                    result = [info]
            except FileNotFoundError:
                pass
        return result

    # Unlimited depth: a single listing without delimiter returns all keys.
    files = self._ls_dirs(path, prefix=prefix, delimiter="")
    if not files and key:
        try:
            files = [self.info(path)]
        except FileNotFoundError:
            files = []

    # A delimiterless listing has no directory entries, so derive them
    # from the file keys when the caller asked for directories.
    if withdirs:
        files.extend(self._extract_parent_directories(files, bucket, key))

    # Exclude directories by default (withdirs is False or None).
    if not withdirs:
        files = [f for f in files if f.type != S3ObjectType.S3_OBJECT_TYPE_DIRECTORY]

    return files
516+
517+
def find(
    self,
    path: str,
    maxdepth: Optional[int] = None,
    withdirs: Optional[bool] = None,
    detail: bool = False,
    **kwargs,
) -> Union[Dict[str, S3Object], List[str]]:
    """List objects below *path*, delegating the traversal to ``_find``.

    Args:
        path: S3 path to search under.
        maxdepth: Maximum directory depth to descend; ``None`` is unlimited.
        withdirs: Whether directory entries are included in the result.
        detail: When True, return a mapping of name to S3Object instead of
            a plain list of names.
        **kwargs: Passed through to ``_find``.

    Returns:
        Either a ``{name: S3Object}`` mapping (``detail=True``) or a list
        of object names (``detail=False``).
    """
    matches = self._find(path=path, maxdepth=maxdepth, withdirs=withdirs, **kwargs)
    if not detail:
        return [obj.name for obj in matches]
    return {obj.name: obj for obj in matches}

0 commit comments

Comments
 (0)