Skip to content

Commit 972a4d0

Browse files
committed
fix all the mypy linting errors
1 parent 914f2b7 commit 972a4d0

27 files changed

Lines changed: 1428 additions & 1027 deletions

.github/workflows/python-package.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ jobs:
4242
4343
poetry config virtualenvs.create false
4444
45-
poetry export --with dev --extras ftp --format requirements.txt --output reqs.txt --without-hashes
45+
poetry export --with dev --extras dbc --format requirements.txt --output reqs.txt --without-hashes
4646
4747
pip install -r reqs.txt
4848
pip install -e ".[dbc]"
4949
50-
# pre-commit run --all-files
50+
pre-commit run --all-files
5151
5252
make test-pysus

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,3 +190,4 @@ cython_debug/
190190
# and can be added to the global gitignore or merged into this file. For a more nuclear
191191
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
192192
.idea/
193+
pyrightconfig.json

.pre-commit-config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,21 @@ repos:
1414
rev: 24.2.0
1515
hooks:
1616
- id: black
17+
args: [--line-length=80]
1718
exclude: ^docs/
1819

1920
- repo: https://github.com/pycqa/isort
2021
rev: 5.13.2
2122
hooks:
2223
- id: isort
24+
args: [--profile=black, --line-length=80]
2325
exclude: ^.*/js/.*$
2426

2527
- repo: https://github.com/pycqa/flake8
2628
rev: 7.0.0
2729
hooks:
2830
- id: flake8
31+
args: [--max-line-length=80, --extend-ignore=E203]
2932
additional_dependencies: [
3033
'flake8-blind-except',
3134
'flake8-bugbear',
@@ -46,6 +49,7 @@ repos:
4649
'pydantic>=2.0.0',
4750
]
4851
args: [--ignore-missing-imports, --explicit-package-bases]
52+
exclude: ^docs/
4953

5054
- repo: https://github.com/asottile/pyupgrade
5155
rev: v3.15.0

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,8 @@ testpaths = [
100100
]
101101

102102
exclude = ["*.git", "docs/"]
103+
104+
[[tool.mypy.overrides]]
105+
module = "tests.*"
106+
disallow_untyped_defs = false
107+
check_untyped_defs = false

pysus/api/client.py

Lines changed: 87 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,18 @@
44
from pathlib import Path
55

66
from pysus import CACHEPATH
7-
from sqlalchemy import Column, DateTime, Enum, Integer, String, create_engine
8-
from sqlalchemy.orm import declarative_base, sessionmaker
7+
from sqlalchemy import DateTime, Enum, Integer, String, create_engine
8+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker
99

1010
from .dadosgov import DadosGovClient
1111
from .ducklake import DuckLakeClient
12-
from .ftp import FTPClient
13-
from .models import BaseLocalFile, BaseRemoteFile
1412
from .extensions import Parquet
13+
from .ftp import FTPClient
14+
from .models import BaseLocalFile, BaseRemoteFile, BaseTabularFile
15+
1516

16-
Base = declarative_base()
17+
class Base(DeclarativeBase):
18+
pass
1719

1820

1921
class DownloadStatus(enum.Enum):
@@ -26,22 +28,28 @@ class DownloadStatus(enum.Enum):
2628

2729
class LocalFileState(Base):
2830
__tablename__ = "local_file_state"
29-
path = Column(String, primary_key=True)
30-
remote_path = Column(String, nullable=False)
31-
client_name = Column(String, nullable=False)
32-
33-
year = Column(Integer, nullable=True)
34-
month = Column(Integer, nullable=True)
35-
state = Column(String, nullable=True)
36-
group = Column(String, nullable=True)
37-
38-
status = Column(Enum(DownloadStatus), default=DownloadStatus.PENDING)
39-
sha256 = Column(String, nullable=True)
40-
last_synced = Column(DateTime, default=datetime.utcnow)
31+
path: Mapped[str] = mapped_column(String, primary_key=True)
32+
remote_path: Mapped[str] = mapped_column(String, nullable=False)
33+
client_name: Mapped[str] = mapped_column(String, nullable=False)
34+
35+
year: Mapped[int | None] = mapped_column(Integer, nullable=True)
36+
month: Mapped[int | None] = mapped_column(Integer, nullable=True)
37+
state: Mapped[str | None] = mapped_column(String, nullable=True)
38+
group: Mapped[str | None] = mapped_column(String, nullable=True)
39+
40+
status: Mapped[DownloadStatus] = mapped_column(
41+
Enum(DownloadStatus),
42+
default=DownloadStatus.PENDING,
43+
)
44+
sha256: Mapped[str | None] = mapped_column(String, nullable=True)
45+
last_synced: Mapped[datetime] = mapped_column(
46+
DateTime,
47+
default=datetime.utcnow,
48+
)
4149

4250

4351
class PySUS:
44-
def __init__(self, db_path: str = CACHEPATH / "config.db"):
52+
def __init__(self, db_path: Path = CACHEPATH / "config.db"):
4553
db_path = Path(db_path)
4654
db_path.parent.mkdir(parents=True, exist_ok=True)
4755

@@ -55,12 +63,24 @@ def __init__(self, db_path: str = CACHEPATH / "config.db"):
5563
self._dadosgov: DadosGovClient | None = None
5664

5765
async def __aenter__(self):
58-
self._ducklake = DuckLakeClient(engine=self.engine)
66+
self._ducklake = DuckLakeClient()
5967
await self._ducklake._load_catalog()
60-
self._attach_client_catalog("ducklake", self._ducklake.catalog_path)
68+
self._attach_client_catalog(
69+
"ducklake", str(self._ducklake.catalog_path)
70+
)
6171
return self
6272

63-
async def get_dadosgov(self, access_token: str) -> DadosGovClient:
73+
async def get_ducklake(self) -> DuckLakeClient:
74+
if self._ducklake is None:
75+
self._ducklake = DuckLakeClient()
76+
await self._ducklake._load_catalog()
77+
self._attach_client_catalog(
78+
"ducklake",
79+
str(self._ducklake.catalog_path),
80+
)
81+
return self._ducklake
82+
83+
async def get_dadosgov(self, access_token: str | None) -> DadosGovClient:
6484
if self._dadosgov is None:
6585
self._dadosgov = DadosGovClient()
6686
await self._dadosgov.connect(token=access_token)
@@ -85,8 +105,8 @@ async def get_local_file(
85105
records = (
86106
session.query(LocalFileState)
87107
.filter_by(
88-
remote_path=remote_path,
89-
client_name=client_name,
108+
remote_path=str(remote_path),
109+
client_name=str(client_name),
90110
status=DownloadStatus.COMPLETED,
91111
)
92112
.all()
@@ -96,11 +116,11 @@ async def get_local_file(
96116
return None
97117

98118
parquet_version = next(
99-
(r for r in records if r.path.endswith(".parquet")), None
119+
(r for r in records if str(r.path).endswith(".parquet")), None
100120
)
101-
file = parquet_version or records[0]
121+
record = parquet_version or records[0]
102122

103-
return await ExtensionFactory.instantiate(file.path)
123+
return await ExtensionFactory.instantiate(str(record.path))
104124

105125
def _attach_client_catalog(self, name: str, path: str):
106126
abs_path = str(Path(path).absolute())
@@ -109,7 +129,9 @@ def _attach_client_catalog(self, name: str, path: str):
109129
existing = conn.exec_driver_sql(q, (abs_path,)).fetchone()
110130

111131
if not existing:
112-
conn.exec_driver_sql(f"ATTACH '{abs_path}' AS {name} (READ_ONLY)")
132+
conn.exec_driver_sql(
133+
f"ATTACH '{abs_path}' AS {name} (READ_ONLY)",
134+
)
113135

114136
async def __aexit__(self, exc_type, exc_val, exc_tb):
115137
if self._ducklake:
@@ -141,19 +163,23 @@ async def _update_state(
141163
remote_path: str,
142164
client_name: str,
143165
status: DownloadStatus,
144-
year: int = None,
145-
month: int = None,
146-
state: str = None,
147-
group: str = None,
166+
year: int | None = None,
167+
month: int | None = None,
168+
state: str | None = None,
169+
group: str | None = None,
148170
):
149171
with self.Session() as session:
150172
record = (
151-
session.query(LocalFileState).filter_by(path=str(local_path)).first()
173+
session.query(LocalFileState)
174+
.filter_by(
175+
path=str(local_path),
176+
)
177+
.first()
152178
)
153179
if not record:
154180
record = LocalFileState(
155181
path=str(local_path),
156-
remote_path=remote_path,
182+
remote_path=str(remote_path),
157183
client_name=client_name,
158184
year=year,
159185
month=month,
@@ -169,8 +195,8 @@ async def _update_state(
169195
async def download(
170196
self,
171197
file: BaseRemoteFile,
172-
token: str = None,
173-
callback: Callable = None,
198+
token: str | None = None,
199+
callback: Callable | None = None,
174200
):
175201
from pysus.api.extensions import ExtensionFactory
176202

@@ -185,24 +211,31 @@ async def download(
185211
local_path.parent.mkdir(parents=True, exist_ok=True)
186212

187213
await self._update_state(
188-
local_path, remote_path, client_name, DownloadStatus.DOWNLOADING
214+
local_path,
215+
str(remote_path),
216+
client_name,
217+
DownloadStatus.DOWNLOADING,
189218
)
190219

220+
client: DuckLakeClient | FTPClient | DadosGovClient
221+
191222
try:
192223
if client_name == "ducklake":
193-
await self._ducklake._download_file(file, local_path, callback)
224+
client = await self.get_ducklake()
194225
elif client_name == "ftp":
195226
client = await self.get_ftp()
196-
await client._download_file(file, local_path, callback)
197227
elif client_name == "dadosgov":
198228
client = await self.get_dadosgov(token)
199-
await client._download_file(file, local_path, callback)
200229
else:
201-
raise ValueError(f"No download logic for client: {client_name}")
230+
raise ValueError(
231+
f"No download logic for client: {client_name}",
232+
)
233+
234+
await client._download_file(file, local_path, callback)
202235

203236
await self._update_state(
204237
local_path=local_path,
205-
remote_path=remote_path,
238+
remote_path=str(remote_path),
206239
client_name=client_name,
207240
status=DownloadStatus.DOWNLOADING,
208241
year=file.year,
@@ -212,11 +245,13 @@ async def download(
212245
)
213246
return await ExtensionFactory.instantiate(local_path)
214247

215-
except Exception:
248+
except Exception as e: # noqa: B902
216249
await self._update_state(
217-
local_path, remote_path, client_name, DownloadStatus.FAILED
250+
local_path, str(remote_path), client_name, DownloadStatus.FAILED
218251
)
219-
raise
252+
raise RuntimeError(
253+
f"Unexpected error downloading {file.basename}: {e}",
254+
) from e
220255

221256
async def _delete_record(self, path: str):
222257
with self.Session() as session:
@@ -228,16 +263,16 @@ async def _delete_record(self, path: str):
228263
async def download_to_parquet(
229264
self,
230265
file: BaseRemoteFile,
231-
token: str = None,
232-
callback: Callable[[int, int], None] = None,
266+
token: str | None = None,
267+
callback: Callable[[int, int], None] | None = None,
233268
) -> Parquet:
234269
local_file = await self.download(
235270
file=file,
236271
token=token,
237272
callback=callback,
238273
)
239274

240-
if not hasattr(local_file, "to_parquet"):
275+
if not isinstance(local_file, BaseTabularFile):
241276
raise NotImplementedError(
242277
f"{local_file} can't be converted to Parquet",
243278
)
@@ -248,7 +283,7 @@ async def download_to_parquet(
248283

249284
await self._update_state(
250285
local_path=parquet_file.path,
251-
remote_path=file.path,
286+
remote_path=str(file.path),
252287
client_name=file.client.name.lower(),
253288
status=DownloadStatus.COMPLETED,
254289
year=file.year,
@@ -270,13 +305,14 @@ def get_local_hierarchy(self):
270305
hierarchy = {}
271306
for r in records:
272307
client = r.client_name.upper()
273-
274-
path_obj = Path(r.path)
308+
path_obj = Path(str(r.path))
275309
parts = path_obj.parts
276310

277311
dataset = parts[-2] if len(parts) > 2 else "Other"
312+
has_group = getattr(r, "group", None) is not None
313+
278314
if path_obj.is_file() and len(parts) > 3:
279-
dataset = parts[-2] if not r.group else parts[-3]
315+
dataset = parts[-2] if has_group else parts[-3]
280316

281317
client_dict = hierarchy.setdefault(client, {})
282318
ds_dict = client_dict.setdefault(dataset, {})

pysus/api/dadosgov/client.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
import pathlib
44
from collections.abc import Callable
55
from datetime import datetime
6-
from typing import Annotated, Any, Dict, List, Optional
6+
from typing import Annotated, Any, Optional
77

88
import httpx
99
from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, PrivateAttr
1010
from pysus import __version__
1111
from pysus.api.models import BaseRemoteClient, BaseRemoteFile
1212

13+
from .models import Dataset
14+
1315

1416
def to_datetime(value: Any) -> datetime | None:
1517
if not value or not isinstance(value, str) or "Indisponível" in value:
@@ -88,7 +90,7 @@ async def close(self) -> None:
8890
await self._client.aclose()
8991
self._client = None
9092

91-
async def datasets(self, **kwargs) -> list[ConjuntoDados]:
93+
async def datasets(self, **kwargs) -> list[Dataset]:
9294
from .databases import AVAILABLE_DATABASES
9395

9496
return [db_class(client=self) for db_class in AVAILABLE_DATABASES]
@@ -117,9 +119,7 @@ async def list_datasets(self, **kwargs) -> list[ConjuntoDados]:
117119
data = response.json()
118120
return [ConjuntoDados(**item, client=self) for item in data]
119121

120-
async def get_dataset(
121-
self, id: str, group_definitions: dict[str, str] | None = None
122-
) -> ConjuntoDados:
122+
async def get_dataset(self, id: str) -> ConjuntoDados:
123123
if self._client is None:
124124
raise ConnectionError(
125125
"Client not connected. Call login(token=...) first.",
@@ -131,7 +131,6 @@ async def get_dataset(
131131
return ConjuntoDados(
132132
**response.json(),
133133
client=self,
134-
group_definitions=group_definitions or {},
135134
)
136135

137136
async def _download_file(
@@ -162,7 +161,9 @@ class Recurso(BaseModel):
162161
title: str = Field(alias="titulo")
163162
url: str = Field(alias="link")
164163
api_size: int = Field(alias="tamanho")
165-
last_modified: DateTime = Field(None, alias="dataUltimaAtualizacaoArquivo")
164+
last_modified: datetime | None = Field(
165+
None, alias="dataUltimaAtualizacaoArquivo"
166+
)
166167
file_name: str | None = Field(None, alias="nomeArquivo")
167168

168169
async def get_size(self) -> int:
@@ -181,6 +182,7 @@ async def get_size(self) -> int:
181182

182183
class ConjuntoDados(BaseModel):
183184
model_config = ConfigDict(populate_by_name=True)
185+
client: BaseRemoteClient | None = None
184186

185187
id: str
186188
title: str = Field(alias="titulo")

0 commit comments

Comments
 (0)