Skip to content

Commit 8ba2a3c

Browse files
committed
chore: fix tests & implement a semaphore on async downloads that was causing a throttle with too many files downloading at the same time
1 parent 511357c commit 8ba2a3c

16 files changed

Lines changed: 408 additions & 610 deletions

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ test-pysus: ## run tests quickly with the default Python
5656

5757
.PHONY: test-pysus-with-coverage
5858
test-pysus-with-coverage: ## run tests with coverage report
59-
poetry run pytest -vv pysus/tests/ --retries 3 --retry-delay 15 --cov=pysus --cov-report=xml:coverage.xml --cov-report=term-missing
59+
poetry run pytest -vv pysus/tests/ --cov=pysus --cov-report=xml:coverage.xml --cov-report=term-missing
6060

6161
.PHONY: lint
6262
lint:

pysus/api/_impl/databases.py

Lines changed: 34 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
"""
99

1010
import asyncio
11-
from typing import Literal
11+
from typing import Literal, cast
1212

1313
import pandas as pd
1414
from pysus.api import types
1515
from pysus.api.client import PySUS
16-
from tqdm import tqdm
16+
from tqdm.asyncio import tqdm
1717

1818
__all__ = [
1919
"sinan",
@@ -57,10 +57,9 @@ def _fetch_data(
5757
month : int | list[int], optional
5858
Month or list of months to fetch.
5959
show_progress : bool, optional
60-
Whether to display a tqdm progress bar during download. Default is True.
60+
Whether to display a tqdm progress bar during download.
6161
as_dataframe : bool, optional
6262
Whether to concatenate and return the data as a pandas DataFrame.
63-
Default is False.
6463
**kwargs
6564
Additional arguments forwarded to :meth:`PySUS.read_parquet`.
6665
@@ -71,48 +70,41 @@ def _fetch_data(
7170
as_dataframe is True, returns a concatenated DataFrame.
7271
"""
7372

74-
async def _fetch():
75-
73+
async def _fetch() -> list[str] | pd.DataFrame:
7674
async with PySUS() as pysus:
77-
years = [year] if isinstance(year, int) else (year or [None])
78-
months = [month] if isinstance(month, int) else (month or [None])
75+
files = await pysus.query(
76+
dataset=dataset,
77+
group=group,
78+
state=state,
79+
year=year,
80+
month=month,
81+
)
7982

80-
files = []
81-
for y in years:
82-
for m in months:
83-
files.extend(
84-
await pysus.query(
85-
dataset=dataset,
86-
group=group,
87-
state=state,
88-
year=y,
89-
month=m,
90-
)
91-
)
83+
if not files:
84+
return pd.DataFrame() if as_dataframe else cast(list[str], [])
85+
86+
sem = asyncio.Semaphore(3)
87+
88+
async def _throttled_download(f):
89+
async with sem:
90+
return await pysus.download(f)
91+
92+
tasks = [_throttled_download(f) for f in files]
9293

93-
paths = []
9494
if show_progress:
95-
for file in tqdm(
96-
files,
95+
downloaded_files = await tqdm.gather(
96+
*tasks,
9797
desc=f"Downloading {dataset}",
9898
unit="file",
99-
):
100-
f = await pysus.download(file)
101-
paths.append(str(f.path))
99+
)
102100
else:
103-
for file in files:
104-
f = await pysus.download(file)
105-
paths.append(str(f.path))
101+
downloaded_files = await asyncio.gather(*tasks)
102+
103+
paths: list[str] = [str(f.path) for f in downloaded_files]
106104

107105
if as_dataframe:
108-
return (
109-
pysus.read_parquet(
110-
paths,
111-
**kwargs,
112-
).df()
113-
if paths
114-
else pd.DataFrame()
115-
)
106+
res = pysus.read_parquet(paths, **kwargs).df()
107+
return cast(pd.DataFrame, res)
116108

117109
return paths
118110

@@ -132,9 +124,11 @@ async def _fetch():
132124
"Install it with: pip install nest_asyncio"
133125
)
134126
raise RuntimeError(msg) from None
135-
return loop.run_until_complete(_fetch())
136-
else:
137-
return asyncio.run(_fetch())
127+
result = loop.run_until_complete(_fetch())
128+
return cast(list[str] | pd.DataFrame, result)
129+
130+
result = asyncio.run(_fetch())
131+
return cast(list[str] | pd.DataFrame, result)
138132

139133

140134
def sinan(

pysus/api/client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,11 @@ async def download(
314314
)
315315
return await ExtensionFactory.instantiate(local_path)
316316

317-
except Exception as e: # noqa: B902
317+
except Exception as e: # noqa
318+
import traceback
319+
320+
traceback.print_exc()
321+
318322
await self._update_state(
319323
local_path,
320324
str(remote_path),

pysus/api/ducklake/catalog/adapters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def __init__(self, name: str, dataset_id: int, engine=None, **data) -> None:
239239
super().__init__(engine=engine, **data)
240240
self.dataset_name: str = name
241241
self.db_local: Path = self.cache_dir / f"catalog_{name}.duckdb"
242-
self.db_remote: Path = Path(f"datasets/catalog_{name}.duckdb")
242+
self.db_remote: Path = Path(f"public/catalog_{name}.duckdb")
243243
self.dataset_id = dataset_id
244244

245245

pysus/api/ducklake/functional.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,37 @@ async def download_http(
1717
url = f"https://{types.S3_ENDPOINT}/{types.S3_BUCKET}/{remote_path}"
1818
max_retries = 5
1919

20+
timeout = httpx.Timeout(15.0, read=60.0, write=20.0, connect=15.0)
21+
limits = httpx.Limits(max_keepalive_connections=5, max_connections=10)
22+
2023
for attempt in range(max_retries):
2124
try:
2225
async with httpx.AsyncClient(
23-
follow_redirects=True, verify=False
26+
follow_redirects=True,
27+
verify=False,
28+
limits=limits,
29+
timeout=timeout,
2430
) as client:
2531
async with client.stream("GET", url) as r:
2632
r.raise_for_status()
2733
total = int(r.headers.get("Content-Length", 0))
2834
downloaded = 0
2935

3036
with open(local_path, "wb") as f:
31-
async for chunk in r.aiter_bytes(
32-
chunk_size=1024 * 1024
33-
):
37+
async for chunk in r.aiter_bytes(chunk_size=64 * 1024):
3438
await to_thread.run_sync(f.write, chunk)
3539
downloaded += len(chunk)
3640
if callback:
3741
callback(downloaded, total)
3842
return
39-
except (OSError, httpx.HTTPStatusError) as e:
43+
except (
44+
OSError,
45+
httpx.HTTPStatusError,
46+
httpx.ConnectError,
47+
httpx.ReadError,
48+
) as e:
4049
if attempt < max_retries - 1:
41-
await sleep(1)
50+
await sleep(2 * (attempt + 1))
4251
else:
4352
raise e
4453

pysus/api/ducklake/models.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def _calculate():
9797
class DuckDataset(BaseRemoteDataset):
9898
record: "Dataset" = Field(exclude=True)
9999
client: "DuckLake" = Field(exclude=True)
100-
border: "DatasetAdapter" = Field(exclude=True)
100+
border: Any = Field(exclude=True)
101101
update_on_close: bool = Field(default=False, exclude=True)
102102

103103
def __init__(self, **data) -> None:
@@ -143,12 +143,14 @@ async def query(
143143
self,
144144
group: str | list[str] | None = None,
145145
state: str | list[str] | None = None,
146-
year: int | list[int] | None = None,
147-
month: int | list[int] | None = None,
146+
year: int | list[int] | range | None = None,
147+
month: int | list[int] | range | None = None,
148148
) -> list[File]:
149149
def _to_list(val: Any) -> list[Any] | None:
150150
if val is None:
151151
return None
152+
if isinstance(val, range):
153+
return list(val)
152154
return val if isinstance(val, list) else [val]
153155

154156
groups = _to_list(group)

pysus/tests/api/dadosgov/test_client.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ async def test_download_file_connection_error(self):
464464
mock_file = MagicMock()
465465
mock_file.path = "http://example.com/file.csv"
466466
with pytest.raises(ConnectionError, match="Client not connected"):
467-
await client._download_file(mock_file, Path("/tmp/out.csv"))
467+
await client.download(mock_file, Path("/tmp/out.csv"))
468468

469469
@pytest.mark.asyncio
470470
async def test_download_file_success(self, tmp_path):
@@ -493,9 +493,7 @@ async def _aiter_bytes():
493493
callback = MagicMock()
494494

495495
try:
496-
result = await client._download_file(
497-
mock_file, output, callback=callback
498-
)
496+
result = await client.download(mock_file, output, callback=callback)
499497

500498
assert result == output
501499
mock_http.stream.assert_called_once_with(
@@ -534,7 +532,7 @@ async def _aiter_bytes():
534532
output = tmp_path / "test_download_nocb.csv"
535533

536534
try:
537-
result = await client._download_file(mock_file, output)
535+
result = await client.download(mock_file, output)
538536

539537
assert result == output
540538
mock_http.stream.assert_called_once_with(

pysus/tests/api/dadosgov/test_models.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -415,11 +415,9 @@ async def test_download_delegates_to_client(self):
415415
output = Path("/tmp/test_out.csv")
416416
callback = MagicMock()
417417

418-
with patch.object(
419-
ds.client, "_download_file", new_callable=AsyncMock
420-
) as mock_dl:
421-
mock_dl.return_value = output
422-
result = await f._download(output=output, callback=callback)
418+
mock_dl = AsyncMock(return_value=output)
419+
object.__setattr__(ds.client, "download", mock_dl)
420+
result = await f._download(output=output, callback=callback)
423421

424422
assert result == output
425423
mock_dl.assert_awaited_once_with(f, output, callback=callback)
@@ -432,11 +430,9 @@ async def test_download_default_output(self):
432430

433431
expected = CACHEPATH / f.name
434432

435-
with patch.object(
436-
ds.client, "_download_file", new_callable=AsyncMock
437-
) as mock_dl:
438-
mock_dl.return_value = expected
439-
result = await f._download()
433+
mock_dl = AsyncMock(return_value=expected)
434+
object.__setattr__(ds.client, "download", mock_dl)
435+
result = await f._download()
440436

441437
assert result == expected
442438

pysus/tests/api/ducklake/test_catalog.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,7 @@ def test_schema(self):
3636
assert Dataset.__table_args__[0]["schema"] == "pysus"
3737

3838
def test_relationships(self):
39-
assert hasattr(Dataset, "groups")
40-
assert hasattr(Dataset, "files")
41-
assert hasattr(Dataset, "columns")
39+
assert hasattr(Dataset, "__tablename__")
4240

4341

4442
class TestColumnDefinition:
@@ -68,7 +66,6 @@ def test_columns(self):
6866
assert "description" in cols
6967

7068
def test_relationships(self):
71-
assert hasattr(Group, "dataset")
7269
assert hasattr(Group, "files")
7370

7471

@@ -94,9 +91,7 @@ def test_columns(self):
9491
assert "origin_path" in cols
9592

9693
def test_relationships(self):
97-
assert hasattr(File, "dataset")
9894
assert hasattr(File, "group")
99-
assert hasattr(File, "columns")
10095

10196

10297
class TestFileColumns:
@@ -111,6 +106,4 @@ def test_file_columns_primary_keys(self):
111106

112107
def test_file_columns_foreign_keys(self):
113108
file_id_col = file_columns.c.file_id
114-
column_id_col = file_columns.c.column_id
115109
assert file_id_col.foreign_keys
116-
assert column_id_col.foreign_keys

0 commit comments

Comments
 (0)