Skip to content

Commit ef82807

Browse files
committed
fix(parquet): include parsings to parquet reading
1 parent 0920e27 commit ef82807

5 files changed

Lines changed: 138 additions & 10 deletions

File tree

pysus/api/client.py

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
from pathlib import Path
1111
from typing import TYPE_CHECKING, Literal
1212

13+
import anyio
1314
import duckdb
15+
import pandas as pd
1416
from pysus import CACHEPATH
1517
from sqlalchemy import DateTime, Enum, Integer, String, create_engine
1618
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker
@@ -235,8 +237,17 @@ async def download(
235237
file: BaseRemoteFile,
236238
token: str | None = None,
237239
callback: Callable | None = None,
240+
timeout: float | None = None,
238241
) -> BaseLocalFile:
239-
"""Download a remote file and return a local file handle."""
242+
"""Download a remote file and return a local file handle.
243+
244+
Parameters
245+
----------
246+
timeout : float | None
247+
Maximum seconds to wait for the download. ``None`` (default) means
248+
no timeout – use this when the socket-level timeout on the
249+
underlying client is sufficient.
250+
"""
240251

241252
from pysus.api.extensions import ExtensionFactory
242253

@@ -271,7 +282,11 @@ async def download(
271282
f"No download logic for client: {client_name}",
272283
)
273284

274-
await client._download_file(file, local_path, callback)
285+
if timeout is not None:
286+
with anyio.fail_after(timeout):
287+
await client._download_file(file, local_path, callback)
288+
else:
289+
await client._download_file(file, local_path, callback)
275290

276291
await self._update_state(
277292
local_path=local_path,
@@ -311,18 +326,22 @@ async def download_to_parquet(
311326
file: BaseRemoteFile,
312327
token: str | None = None,
313328
callback: Callable[[int, int], None] | None = None,
329+
timeout: float | None = None,
330+
add_dv: bool = True,
314331
) -> Parquet:
315332
"""Download a file and convert it to Parquet format."""
316333

317334
local_file = await self.download(
318335
file=file,
319336
token=token,
320337
callback=callback,
338+
timeout=timeout,
321339
)
322340

323341
if hasattr(local_file, "to_parquet"):
324342
original_path = local_file.path
325343
parquet_file = await local_file.to_parquet(callback=callback)
344+
parquet_file.add_dv = add_dv
326345

327346
await self._update_state(
328347
local_path=parquet_file.path,
@@ -346,7 +365,9 @@ async def download_to_parquet(
346365
)
347366

348367
def get_local_hierarchy(self):
349-
"""Build a nested dict of cached files grouped by client and dataset."""
368+
"""
369+
Build a nested dict of cached files grouped by client and dataset.
370+
"""
350371

351372
with self.Session() as session:
352373
records = session.query(LocalFileState).all()
@@ -414,8 +435,20 @@ def read_parquet(
414435
paths: list[Path],
415436
sql: str | None = None,
416437
mode: Literal["union", "intersection", "strict"] = "union",
417-
) -> "DuckDBPyConnection":
418-
"""Read Parquet files with optional schema handling and SQL filter."""
438+
add_dv: bool = True,
439+
) -> "DuckDBPyConnection | pd.DataFrame":
440+
"""Read Parquet files with optional schema handling and SQL filter.
441+
442+
Parameters
443+
----------
444+
add_dv : bool
445+
When True, automatically applies the IBGE verification digit to
446+
municipality code columns. If there are matching columns, a
447+
DataFrame is returned instead of a DuckDBPyConnection.
448+
"""
449+
450+
from pysus.api.utils import add_dv as _add_dv_fn
451+
from pysus.api.utils import is_geocode_column
419452

420453
if not paths:
421454
raise ValueError("No paths provided")
@@ -452,8 +485,7 @@ def get_columns(path: Path) -> set[tuple[str, str]]:
452485
else:
453486
paths_str = ", ".join(f"'{p}'" for p in paths)
454487
query = (
455-
f"SELECT * FROM read_parquet([{paths_str}], "
456-
"union_by_name=True)"
488+
f"SELECT * FROM read_parquet([{paths_str}], union_by_name=True)"
457489
)
458490

459491
if sql:
@@ -462,4 +494,29 @@ def get_columns(path: Path) -> set[tuple[str, str]]:
462494
else:
463495
query = f"SELECT {sql} FROM ({query}) AS t"
464496

497+
base = duckdb.execute(query)
498+
499+
if not add_dv:
500+
return base
501+
502+
geocode_cols = [
503+
col[0] for col in base.description if is_geocode_column(col[0])
504+
]
505+
if not geocode_cols:
506+
return base
507+
508+
duckdb.create_function(
509+
"__pysus_add_dv",
510+
_add_dv_fn,
511+
null_handling="special",
512+
)
513+
selects = [
514+
(
515+
f'__pysus_add_dv("{c[0]}") AS "{c[0]}"'
516+
if c[0] in geocode_cols
517+
else f'"{c[0]}"'
518+
)
519+
for c in base.description
520+
]
521+
query = f"SELECT {', '.join(selects)} FROM ({query}) AS _t"
465522
return duckdb.execute(query)

pysus/api/extensions.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ class Parquet(BaseTabularFile):
188188
"""Represents a Parquet file with optional date and integer type parsing."""
189189

190190
type: FileType = Field("PARQUET")
191+
add_dv: bool = True
191192

192193
@property
193194
def schema(self) -> pa.Schema:
@@ -204,12 +205,26 @@ def rows(self) -> int:
204205
"""Return the number of rows from the Parquet metadata."""
205206
return pq.read_metadata(self.path).num_rows
206207

208+
@staticmethod
def _apply_add_dv(df: pd.DataFrame) -> pd.DataFrame:
    """Apply the IBGE verification digit to geocode columns in-place.

    Every column whose name satisfies ``is_geocode_column`` is rewritten
    by passing each non-null cell, converted to ``str``, through
    ``add_dv``. Missing cells (``None``, ``NaN``, ``pd.NA``) are left
    untouched.

    Parameters
    ----------
    df : pd.DataFrame
        Frame to update. It is mutated and also returned.

    Returns
    -------
    pd.DataFrame
        The same ``df`` object that was passed in.
    """
    from pysus.api.utils import add_dv, is_geocode_column

    def _with_dv(value):
        # A blanket ``astype(str)`` would turn missing values into the
        # literal strings "nan" / "None"; keep them as real nulls instead.
        if pd.isna(value):
            return value
        return add_dv(str(value))

    geocode_cols = [c for c in df.columns if is_geocode_column(c)]
    for col in geocode_cols:
        # ``astype(object)`` gives plain element-wise semantics regardless
        # of the column's original dtype while preserving null markers.
        df[col] = df[col].astype(object).apply(_with_dv)
    return df
217+
207218
async def load(self, parse: bool = True) -> pd.DataFrame:
    """Read the entire Parquet file into a DataFrame.

    The blocking read is delegated to ``to_thread.run_sync``.

    Parameters
    ----------
    parse : bool
        When True, the frame is passed through ``self.parse_dftypes``
        before being returned.

    Returns
    -------
    pd.DataFrame
        File contents; geocode columns additionally go through
        ``self._apply_add_dv`` when ``self.add_dv`` is set.
    """

    def _read_frame():
        frame = pd.read_parquet(self.path, engine="pyarrow")
        if parse:
            frame = self.parse_dftypes(frame)
        return self._apply_add_dv(frame) if self.add_dv else frame

    return await to_thread.run_sync(_read_frame)
215230

@@ -226,6 +241,8 @@ async def stream(
226241
df = batch.to_pandas()
227242
if parse:
228243
df = self.parse_dftypes(df)
244+
if self.add_dv:
245+
df = self._apply_add_dv(df)
229246
yield df
230247
await asyncio.sleep(0)
231248

pysus/api/ftp/client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class FTP(BaseRemoteClient):
4343
"""Async FTP client for navigating and downloading DATASUS data."""
4444

4545
host: str = "ftp.datasus.gov.br"
46+
timeout: int = 60
4647

4748
_ftp: FTPLib | None = PrivateAttr(default=None)
4849

@@ -77,7 +78,7 @@ async def connect(self) -> None:
7778

7879
def _connect():
7980
if self.ftp is None:
80-
self._ftp = FTPLib(self.host)
81+
self._ftp = FTPLib(self.host, timeout=self.timeout)
8182
self.ftp.login()
8283

8384
await to_thread.run_sync(_connect)

pysus/api/utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Upper-case column-name prefixes that mark a column as holding a
# municipality code. Kept as a tuple so it can be handed straight to
# ``str.startswith``.
GEOCODE_PREFIXES = (
    "ID_MUNICIP",
    "ID_MN_RESI",
    "ID_MUNI_RE",
    "MUN_",
    "COD_MUN_",
    "CO_MUN_",
    "ID_MUNI_AT",
    "ID_MUNIC_",
)


def is_geocode_column(name: str) -> bool:
    """Check if a column name corresponds to an IBGE municipality code.

    The match is case-insensitive: ``name`` is upper-cased and compared
    against every prefix in ``GEOCODE_PREFIXES``.

    Parameters
    ----------
    name : str
        Column label to test. Non-string labels never match.

    Returns
    -------
    bool
        True when ``name`` starts with one of the known prefixes.
    """
    # Column labels are not guaranteed to be strings (e.g. integer labels
    # in a DataFrame); treat anything else as "not a geocode column"
    # instead of failing on ``.upper()``.
    if not isinstance(name, str):
        return False
    return name.upper().startswith(GEOCODE_PREFIXES)
17+
18+
19+
# Weights applied, position by position, to the six digits of a code when
# computing its verification digit.
_DV_WEIGHTS = (1, 2, 1, 2, 1, 2)

# 7-digit codes for which the arithmetically computed verification digit
# must be replaced: computed code -> code to return.
# NOTE(review): presumably the values are the officially assigned codes;
# confirm against the IBGE municipality table.
_MISCALCULATED_GEOCODES = {
    "2201911": "2201919",
    "2201986": "2201988",
    "2202257": "2202251",
    "2611531": "2611533",
    "3117835": "3117836",
    "3152139": "3152131",
    "4305876": "4305871",
    "5203963": "5203962",
    "5203930": "5203939",
}


def add_dv(geocode: str) -> str:
    """Append the verification digit to a 6-digit municipality code.

    Parameters
    ----------
    geocode : str
        Code to complete. Falsy values and anything whose text form is
        not made only of ASCII digits are returned unchanged.

    Returns
    -------
    str
        * 6 digits: the code with its computed verification digit
          appended, after applying the known corrections.
        * 7 digits: the corrected code when it is one of the known
          miscalculated codes, otherwise the input unchanged.
        * any other input: the input unchanged.
    """
    if not geocode:
        return geocode

    text = str(geocode)
    # ``isdigit`` alone also accepts characters such as "²", which
    # ``int`` cannot parse; requiring ASCII keeps such values on the
    # pass-through path instead of raising ValueError below.
    if not (text.isascii() and text.isdigit()):
        return geocode

    if len(text) == 7:
        return _MISCALCULATED_GEOCODES.get(text, geocode)

    if len(text) == 6:
        # Each weighted digit contributes the sum of its own decimal
        # digits (e.g. 5 * 2 = 10 -> 1 + 0).
        total = sum(
            sum(divmod(int(digit) * weight, 10))
            for digit, weight in zip(text, _DV_WEIGHTS)
        )
        dv = (10 - total % 10) % 10
        code = f"{text}{dv}"
        return _MISCALCULATED_GEOCODES.get(code, code)

    return geocode

pysus/tests/api/ftp/test_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ async def test_connect_and_login(ftp_client):
4848
mock_instance = mock_ftplib.return_value
4949
await ftp_client.login()
5050

51-
mock_ftplib.assert_called_once_with(ftp_client.host)
51+
mock_ftplib.assert_called_once_with(
52+
ftp_client.host, timeout=ftp_client.timeout
53+
)
5254
mock_instance.login.assert_called_once()
5355

5456

0 commit comments

Comments
 (0)