44from pathlib import Path
55
66from pysus import CACHEPATH
7- from sqlalchemy import Column , DateTime , Enum , Integer , String , create_engine
8- from sqlalchemy .orm import declarative_base , sessionmaker
7+ from sqlalchemy import DateTime , Enum , Integer , String , create_engine
8+ from sqlalchemy .orm import DeclarativeBase , Mapped , mapped_column , sessionmaker
99
1010from .dadosgov import DadosGovClient
1111from .ducklake import DuckLakeClient
12- from .ftp import FTPClient
13- from .models import BaseLocalFile , BaseRemoteFile
1412from .extensions import Parquet
13+ from .ftp import FTPClient
14+ from .models import BaseLocalFile , BaseRemoteFile , BaseTabularFile
15+
1516
16- Base = declarative_base ()
17+ class Base (DeclarativeBase ):
18+ pass
1719
1820
1921class DownloadStatus (enum .Enum ):
@@ -26,22 +28,28 @@ class DownloadStatus(enum.Enum):
2628
2729class LocalFileState (Base ):
2830 __tablename__ = "local_file_state"
29- path = Column (String , primary_key = True )
30- remote_path = Column (String , nullable = False )
31- client_name = Column (String , nullable = False )
32-
33- year = Column (Integer , nullable = True )
34- month = Column (Integer , nullable = True )
35- state = Column (String , nullable = True )
36- group = Column (String , nullable = True )
37-
38- status = Column (Enum (DownloadStatus ), default = DownloadStatus .PENDING )
39- sha256 = Column (String , nullable = True )
40- last_synced = Column (DateTime , default = datetime .utcnow )
31+ path : Mapped [str ] = mapped_column (String , primary_key = True )
32+ remote_path : Mapped [str ] = mapped_column (String , nullable = False )
33+ client_name : Mapped [str ] = mapped_column (String , nullable = False )
34+
35+ year : Mapped [int | None ] = mapped_column (Integer , nullable = True )
36+ month : Mapped [int | None ] = mapped_column (Integer , nullable = True )
37+ state : Mapped [str | None ] = mapped_column (String , nullable = True )
38+ group : Mapped [str | None ] = mapped_column (String , nullable = True )
39+
40+ status : Mapped [DownloadStatus ] = mapped_column (
41+ Enum (DownloadStatus ),
42+ default = DownloadStatus .PENDING ,
43+ )
44+ sha256 : Mapped [str | None ] = mapped_column (String , nullable = True )
45+ last_synced : Mapped [datetime ] = mapped_column (
46+ DateTime ,
47+ default = datetime .utcnow ,
48+ )
4149
4250
4351class PySUS :
44- def __init__ (self , db_path : str = CACHEPATH / "config.db" ):
52+ def __init__ (self , db_path : Path = CACHEPATH / "config.db" ):
4553 db_path = Path (db_path )
4654 db_path .parent .mkdir (parents = True , exist_ok = True )
4755
@@ -55,12 +63,24 @@ def __init__(self, db_path: str = CACHEPATH / "config.db"):
5563 self ._dadosgov : DadosGovClient | None = None
5664
5765 async def __aenter__ (self ):
58- self ._ducklake = DuckLakeClient (engine = self . engine )
66+ self ._ducklake = DuckLakeClient ()
5967 await self ._ducklake ._load_catalog ()
60- self ._attach_client_catalog ("ducklake" , self ._ducklake .catalog_path )
68+ self ._attach_client_catalog (
69+ "ducklake" , str (self ._ducklake .catalog_path )
70+ )
6171 return self
6272
63- async def get_dadosgov (self , access_token : str ) -> DadosGovClient :
73+ async def get_ducklake (self ) -> DuckLakeClient :
74+ if self ._ducklake is None :
75+ self ._ducklake = DuckLakeClient ()
76+ await self ._ducklake ._load_catalog ()
77+ self ._attach_client_catalog (
78+ "ducklake" ,
79+ str (self ._ducklake .catalog_path ),
80+ )
81+ return self ._ducklake
82+
83+ async def get_dadosgov (self , access_token : str | None ) -> DadosGovClient :
6484 if self ._dadosgov is None :
6585 self ._dadosgov = DadosGovClient ()
6686 await self ._dadosgov .connect (token = access_token )
@@ -85,8 +105,8 @@ async def get_local_file(
85105 records = (
86106 session .query (LocalFileState )
87107 .filter_by (
88- remote_path = remote_path ,
89- client_name = client_name ,
108+ remote_path = str ( remote_path ) ,
109+ client_name = str ( client_name ) ,
90110 status = DownloadStatus .COMPLETED ,
91111 )
92112 .all ()
@@ -96,11 +116,11 @@ async def get_local_file(
96116 return None
97117
98118 parquet_version = next (
99- (r for r in records if r .path .endswith (".parquet" )), None
119+ (r for r in records if str ( r .path ) .endswith (".parquet" )), None
100120 )
101- file = parquet_version or records [0 ]
121+ record = parquet_version or records [0 ]
102122
103- return await ExtensionFactory .instantiate (file .path )
123+ return await ExtensionFactory .instantiate (str ( record .path ) )
104124
105125 def _attach_client_catalog (self , name : str , path : str ):
106126 abs_path = str (Path (path ).absolute ())
@@ -109,7 +129,9 @@ def _attach_client_catalog(self, name: str, path: str):
109129 existing = conn .exec_driver_sql (q , (abs_path ,)).fetchone ()
110130
111131 if not existing :
112- conn .exec_driver_sql (f"ATTACH '{ abs_path } ' AS { name } (READ_ONLY)" )
132+ conn .exec_driver_sql (
133+ f"ATTACH '{ abs_path } ' AS { name } (READ_ONLY)" ,
134+ )
113135
114136 async def __aexit__ (self , exc_type , exc_val , exc_tb ):
115137 if self ._ducklake :
@@ -141,19 +163,23 @@ async def _update_state(
141163 remote_path : str ,
142164 client_name : str ,
143165 status : DownloadStatus ,
144- year : int = None ,
145- month : int = None ,
146- state : str = None ,
147- group : str = None ,
166+ year : int | None = None ,
167+ month : int | None = None ,
168+ state : str | None = None ,
169+ group : str | None = None ,
148170 ):
149171 with self .Session () as session :
150172 record = (
151- session .query (LocalFileState ).filter_by (path = str (local_path )).first ()
173+ session .query (LocalFileState )
174+ .filter_by (
175+ path = str (local_path ),
176+ )
177+ .first ()
152178 )
153179 if not record :
154180 record = LocalFileState (
155181 path = str (local_path ),
156- remote_path = remote_path ,
182+ remote_path = str ( remote_path ) ,
157183 client_name = client_name ,
158184 year = year ,
159185 month = month ,
@@ -169,8 +195,8 @@ async def _update_state(
169195 async def download (
170196 self ,
171197 file : BaseRemoteFile ,
172- token : str = None ,
173- callback : Callable = None ,
198+ token : str | None = None ,
199+ callback : Callable | None = None ,
174200 ):
175201 from pysus .api .extensions import ExtensionFactory
176202
@@ -185,24 +211,31 @@ async def download(
185211 local_path .parent .mkdir (parents = True , exist_ok = True )
186212
187213 await self ._update_state (
188- local_path , remote_path , client_name , DownloadStatus .DOWNLOADING
214+ local_path ,
215+ str (remote_path ),
216+ client_name ,
217+ DownloadStatus .DOWNLOADING ,
189218 )
190219
220+ client : DuckLakeClient | FTPClient | DadosGovClient
221+
191222 try :
192223 if client_name == "ducklake" :
193- await self ._ducklake . _download_file ( file , local_path , callback )
224+ client = await self .get_ducklake ( )
194225 elif client_name == "ftp" :
195226 client = await self .get_ftp ()
196- await client ._download_file (file , local_path , callback )
197227 elif client_name == "dadosgov" :
198228 client = await self .get_dadosgov (token )
199- await client ._download_file (file , local_path , callback )
200229 else :
201- raise ValueError (f"No download logic for client: { client_name } " )
230+ raise ValueError (
231+ f"No download logic for client: { client_name } " ,
232+ )
233+
234+ await client ._download_file (file , local_path , callback )
202235
203236 await self ._update_state (
204237 local_path = local_path ,
205- remote_path = remote_path ,
238+ remote_path = str ( remote_path ) ,
206239 client_name = client_name ,
207240 status = DownloadStatus .DOWNLOADING ,
208241 year = file .year ,
@@ -212,11 +245,13 @@ async def download(
212245 )
213246 return await ExtensionFactory .instantiate (local_path )
214247
215- except Exception :
248+ except Exception as e : # noqa: B902
216249 await self ._update_state (
217- local_path , remote_path , client_name , DownloadStatus .FAILED
250+ local_path , str ( remote_path ) , client_name , DownloadStatus .FAILED
218251 )
219- raise
252+ raise RuntimeError (
253+ f"Unexpected error downloading { file .basename } : { e } " ,
254+ ) from e
220255
221256 async def _delete_record (self , path : str ):
222257 with self .Session () as session :
@@ -228,16 +263,16 @@ async def _delete_record(self, path: str):
228263 async def download_to_parquet (
229264 self ,
230265 file : BaseRemoteFile ,
231- token : str = None ,
232- callback : Callable [[int , int ], None ] = None ,
266+ token : str | None = None ,
267+ callback : Callable [[int , int ], None ] | None = None ,
233268 ) -> Parquet :
234269 local_file = await self .download (
235270 file = file ,
236271 token = token ,
237272 callback = callback ,
238273 )
239274
240- if not hasattr (local_file , "to_parquet" ):
275+ if not isinstance (local_file , BaseTabularFile ):
241276 raise NotImplementedError (
242277 f"{ local_file } can't be converted to Parquet" ,
243278 )
@@ -248,7 +283,7 @@ async def download_to_parquet(
248283
249284 await self ._update_state (
250285 local_path = parquet_file .path ,
251- remote_path = file .path ,
286+ remote_path = str ( file .path ) ,
252287 client_name = file .client .name .lower (),
253288 status = DownloadStatus .COMPLETED ,
254289 year = file .year ,
@@ -270,13 +305,14 @@ def get_local_hierarchy(self):
270305 hierarchy = {}
271306 for r in records :
272307 client = r .client_name .upper ()
273-
274- path_obj = Path (r .path )
308+ path_obj = Path (str (r .path ))
275309 parts = path_obj .parts
276310
277311 dataset = parts [- 2 ] if len (parts ) > 2 else "Other"
312+ has_group = getattr (r , "group" , None ) is not None
313+
278314 if path_obj .is_file () and len (parts ) > 3 :
279- dataset = parts [- 2 ] if not r . group else parts [- 3 ]
315+ dataset = parts [- 2 ] if has_group else parts [- 3 ]
280316
281317 client_dict = hierarchy .setdefault (client , {})
282318 ds_dict = client_dict .setdefault (dataset , {})
0 commit comments